whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -0,0 +1,2503 @@
1
+ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
2
+ #pragma clang diagnostic ignored "-Wunused-function"
3
+ #pragma clang diagnostic ignored "-Wunused-variable"
4
+ #pragma clang diagnostic ignored "-Wunused-but-set-variable"
5
+
6
+ #ifdef HTP_DEBUG
7
+ # define FARF_HIGH 1
8
+ #endif
9
+
10
+ #include <HAP_farf.h>
11
+ #include <HAP_mem.h>
12
+ #include <HAP_perf.h>
13
+ #include <HAP_ps.h>
14
+ #include <hexagon_protos.h>
15
+ #include <hexagon_types.h>
16
+ #include <math.h>
17
+ #include <qurt_thread.h>
18
+ #include <string.h>
19
+
20
+ #define GGML_COMMON_DECL_C
21
+ #include "ggml-common.h"
22
+ #include "htp-ctx.h"
23
+ #include "htp-dma.h"
24
+ #include "htp-msg.h"
25
+ #include "htp-ops.h"
26
+ #include "hvx-utils.h"
27
+ #include "ops-utils.h"
28
+
29
+ #define MM_SPAD_SRC0_NROWS 16
30
+ #define MM_SPAD_SRC1_NROWS 16
31
+ #define MM_SPAD_DST_NROWS 2
32
+
33
// Dispatch entry for one quantization type's matmul kernels:
// a human-readable type name plus single-row and two-row dot-product
// implementations (the rx2 variant processes two consecutive rows of
// vx that are vx_row_size bytes apart — see vec_dot_q4x4x2_q8x4x2_rx2).
struct htp_matmul_type {
    const char * type;
    // s[0] = dot(vx, vy) over n elements
    void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
    // s[0..1] = dot of two adjacent vx rows against vy
    void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
};
38
+
39
// Small fixed-size bundles of HVX vectors so helper functions can
// return several 128-byte vectors by value in one struct.
typedef struct {
    HVX_Vector v[2];
} HVX_Vector_x2;

typedef struct {
    HVX_Vector v[4];
} HVX_Vector_x4;

typedef struct {
    HVX_Vector v[8];
} HVX_Vector_x8;
50
+
51
// The tables below are precomputed control words for the HVX vdelta/vlut
// permute network. They must stay byte-exact; do not reformat the values.

// vdelta control to replicate first 4x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
    0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
};

// vdelta control to replicate and interleave first 8x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
    0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
    0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
};

// vdelta control to replicate first fp32 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
    0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
    0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
    0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
    0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};

// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
    0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
    0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
    0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
    0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
    0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
    0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};

// vdelta control to replicate the first 2x fp16 values across all elements
// (NOTE(review): original comment said "first fp16 value", but the name and
// the 64-byte-periodic pattern indicate a 2-value replication — confirm)
static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};

// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
    0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
    0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
    0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
    0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
    0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
    0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
    0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
};

// vlut32 table mapping 4-bit mxfp4 codes to their signed integer kvalues
// (entries are widened to 2 bytes each; high half holds the negative codes)
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
    0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
    0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
125
+
126
+ // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
127
+
128
+ static inline size_t q8x4x2_row_size(uint32_t ne) {
129
+ // ensures perfect alignment of quants and full row
130
+ const uint32_t qk = QK_Q8_0x4x2;
131
+ const uint32_t nb = (ne + qk - 1) / qk;
132
+ return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
133
+ }
134
+
135
// Load 1024 packed q4 quants (512 bytes = 4 HVX vectors) and unpack them
// into 8 HVX vectors of signed int8 values in [-8, 7].
// Low nibbles of each byte become one output vector, high nibbles the next.
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;

    HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
    HVX_Vector v2_3 = vptr[1]; // ...
    HVX_Vector v4_5 = vptr[2]; // ...
    HVX_Vector v6_7 = vptr[3]; // ...

    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);

    // Split each byte into its two 4-bit quants
    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4

    // Convert uint4 to int4 (i.e. x - 8)
    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
    v0 = Q6_Vb_vsub_VbVb(v0, i8);
    v1 = Q6_Vb_vsub_VbVb(v1, i8);
    v2 = Q6_Vb_vsub_VbVb(v2, i8);
    v3 = Q6_Vb_vsub_VbVb(v3, i8);
    v4 = Q6_Vb_vsub_VbVb(v4, i8);
    v5 = Q6_Vb_vsub_VbVb(v5, i8);
    v6 = Q6_Vb_vsub_VbVb(v6, i8);
    v7 = Q6_Vb_vsub_VbVb(v7, i8);

    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
    return r;
}
168
+
169
// Load 1024 packed mxfp4 quants (512 bytes = 4 HVX vectors) and unpack them
// into 8 HVX vectors of signed int8, mapping each 4-bit code through the
// kvalues_mxfp4_lut table via vlut32.
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;

    HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
    HVX_Vector v2_3 = vptr[1]; // ...
    HVX_Vector v4_5 = vptr[2]; // ...
    HVX_Vector v6_7 = vptr[3]; // ...

    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);

    // Split each byte into its two 4-bit codes
    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
    HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
    HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
    HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
    HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
    HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4

    // Translate 4-bit codes to signed mxfp4 kvalues
    HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
    v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
    v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
    v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
    v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
    v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
    v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
    v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
    v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);

    HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
    return r;
}
201
+
202
+ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
203
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
204
+
205
+ HVX_Vector v0 = vptr[0]; // first 128 vals
206
+ HVX_Vector v1 = vptr[1]; // ...
207
+ HVX_Vector v2 = vptr[2]; // ...
208
+ HVX_Vector v3 = vptr[3]; // ...
209
+ HVX_Vector v4 = vptr[4]; // ...
210
+ HVX_Vector v5 = vptr[5]; // ...
211
+ HVX_Vector v6 = vptr[6]; // ...
212
+ HVX_Vector v7 = vptr[7]; // ...
213
+
214
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
215
+ return r;
216
+ }
217
+
218
+ static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
219
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
220
+
221
+ HVX_Vector v0 = vptr[0]; // first 64 vals
222
+ HVX_Vector v1 = vptr[1]; // second 64 vals
223
+ HVX_Vector v2 = vptr[2]; // third 64 vals
224
+ HVX_Vector v3 = vptr[3]; // forth 64 vals
225
+
226
+ HVX_Vector_x4 r = { v0, v1, v2, v3 };
227
+ return r;
228
+ }
229
+
230
// Load 256 fp32 values (4 HVX vector pairs) and convert them to 4 HVX
// vectors of fp16, preserving element order.
static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
    const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;

    HVX_VectorPair v0 = vptr[0]; // first 64 vals
    HVX_VectorPair v1 = vptr[1]; // second 64 vals
    HVX_VectorPair v2 = vptr[2]; // third 64 vals
    HVX_VectorPair v3 = vptr[3]; // fourth 64 vals

    // Subtracting zero converts IEEE sf lanes into the qf32 format, which
    // Q6_Vhf_equals_Wqf32 can then narrow to fp16.
    HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
    HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
    HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
    HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
    HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
    HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
    HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
    HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());

    // Narrow each qf32 pair into one fp16 vector
    HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
    HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
    HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
    HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));

    // vcombine does a shuffle, use vdeal to undo

    HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
    return r;
}
257
+
258
// Reduce-multiply two sets of 1024 int8 elements (32x q4/8 sub-blocks held
// in 8x HVX vectors per operand).
// Accumulate each 32-element sub-block into a single int32 value.
// Return a single HVX vector with 32x int32 accumulators.
// This version is parameterized to support less than 1024 elements.
// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.

static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
    HVX_Vector r0 = Q6_V_vsplat_R(0);
    HVX_Vector r1 = Q6_V_vsplat_R(0);
    HVX_Vector r2 = Q6_V_vsplat_R(0);
    HVX_Vector r3 = Q6_V_vsplat_R(0);
    HVX_Vector r4 = Q6_V_vsplat_R(0);
    HVX_Vector r5 = Q6_V_vsplat_R(0);
    HVX_Vector r6 = Q6_V_vsplat_R(0);
    HVX_Vector r7 = Q6_V_vsplat_R(0);

    HVX_VectorPair p3;
    HVX_VectorPair p2;
    HVX_VectorPair p1;
    HVX_VectorPair p0;

    // Stage 0: vrmpy each vector — every int32 lane accumulates 4 byte products
    if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
    if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
    if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
    if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
    if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
    if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
    if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }

    // Stages 1..3: log2 reduction tree — vdeal pairs of vectors then add,
    // halving the number of partial vectors each round
    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
    if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
    if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }

    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
    if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
    if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
    if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }

    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }

    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
    if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }

    if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
    if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }

    return r0;
}
309
+
310
// Full-width variant: all 8 vectors (1024 elements) are valid.
static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
    return hvx_vec_rmpy_x8_n(x, y, 1024);
}
313
+
314
+ // Handle most common cases of tensors not multiple of 1024.
315
+ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
316
+ if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
317
+ if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
318
+ if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
319
+ return hvx_vec_rmpy_x8_n(x, y, 1024);
320
+ }
321
+
322
// Dot product of one q4x4x2 row (vx) with one q8x4x2 row (vy) over n
// elements; the scalar fp32 result is written to s[0].
// Both rows store quants first, then __fp16 scales (see q8x4x2_row_size).
static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % 32 == 0); // min sub-block size
    assert((unsigned long) vx % 128 == 0);
    assert((unsigned long) vy % 128 == 0);

    const uint32_t qk = QK_Q4_0x4x2 * 4; // elements per 4x4x2 super-block

    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
    const uint32_t x_qblk_size = qk / 2; // int4
    const uint32_t x_qrow_size = n / 2; // int4 (not padded)

    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
    const uint32_t y_qblk_size = qk; // int8
    const uint32_t y_qrow_size = n; // int8 (not padded)

    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales

    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales

    // Row sum (qf32)
    HVX_Vector r0_sum = Q6_V_vsplat_R(0);

    // Multiply and accumulate into int32.
    // Compute combined scale (fp32).
    // Apply scale to acc and accumulate into the row sum (qf32).

    const uint32_t nb = n / qk; // num full blocks
    const uint32_t nloe = n % qk; // num leftover elements

    uint32_t i = 0;
    for (; i < nb; i++) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);

        // 32x int32 sub-block accumulators, converted to fp32
        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

        // Load the 32 fp16 scales of each operand for this block
        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));

        // Combined per-sub-block scale d_x * d_y (fp32)
        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
    if (nloe) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);

        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));

        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));

        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));

        // Zero out unused scales (nloe/32 valid fp32 scales = nloe/8 bytes)
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
        r0_dd = Q6_V_vand_QV(bmask, r0_dd);

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Reduce and convert into fp32
    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));

    hvx_vec_store_u(&s[0], 4, r0_sum);
}
396
+
397
+ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
398
+ float * restrict s,
399
+ const void * restrict vx,
400
+ uint32_t vx_row_size,
401
+ const void * restrict vy) {
402
+ assert(n % 32 == 0); // min sub-block size
403
+ assert((unsigned long) vx % 128 == 0);
404
+ assert((unsigned long) vy % 128 == 0);
405
+
406
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
407
+
408
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
409
+ const uint32_t x_qblk_size = qk / 2; // int4
410
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
411
+
412
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
413
+ const uint32_t y_qblk_size = qk; // int8
414
+ const uint32_t y_qrow_size = n; // int8 (not padded)
415
+
416
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
417
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
418
+
419
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
420
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
421
+
422
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
423
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
424
+
425
+ // Row sum (qf32)
426
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
427
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
428
+
429
+ // Multiply and accumulate into int32.
430
+ // Compute combined scale (fp32).
431
+ // Apply scale to acc and accumulate into the row sum (qf32).
432
+
433
+ const uint32_t nb = n / qk; // num full blocks
434
+ const uint32_t nloe = n % qk; // num leftover elemements
435
+
436
+ uint32_t i = 0;
437
+ for (; i < nb; i++) {
438
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
439
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
440
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
441
+
442
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
443
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
444
+
445
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
446
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
447
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
448
+
449
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
450
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
451
+
452
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
453
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
454
+
455
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
456
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
457
+ }
458
+
459
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
460
+ if (nloe) {
461
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
462
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
463
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
464
+
465
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
466
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
467
+
468
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
469
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
470
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
471
+
472
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
473
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
474
+
475
+ // Zero out unused scales
476
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
477
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
478
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
479
+
480
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
481
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
482
+
483
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
484
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
485
+ }
486
+
487
+ // Convert into fp32 and reduce
488
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
489
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
490
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
491
+
492
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
493
+ }
494
+
495
// Dot product of one q8x4x2 row (vx) against one q8x4x2 row (vy).
// Row layout for both operands: all int8 quants first, then per-block fp16 scales.
// Per block: int8 products are accumulated into int32, multiplied by the combined
// (x_scale * y_scale) fp32 scale, and summed across blocks in qf32.
// The final fp32 scalar is stored to s[0].
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % 32 == 0); // min sub-block size
    assert((unsigned long) vx % 128 == 0);
    assert((unsigned long) vy % 128 == 0);

    const uint32_t qk = QK_Q4_0x4x2 * 4; // elements per full block

    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
    const uint32_t x_qblk_size = qk;        // int8
    const uint32_t x_qrow_size = n;         // int8 (not padded)

    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
    const uint32_t y_qblk_size = qk;        // int8
    const uint32_t y_qrow_size = n;         // int8 (not padded)

    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);           // quants first
    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales

    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);           // quants first
    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales

    // Row sum (qf32)
    HVX_Vector r0_sum = Q6_V_vsplat_R(0);

    // Per iteration:
    //   Multiply and accumulate into int32.
    //   Compute combined scale (fp32).
    //   Apply scale to acc and accumulate into the row sum (qf32).

    const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;      // num leftover elements (must be signed)

    uint32_t i = 0;
    for (; i < nb; i++) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);

        // Per-block int32 dot products, converted to fp32
        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

        // Per-block fp16 scales, shuffled to line up with the accumulator lanes
        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));

        // Combined per-block scale = x_scale * y_scale (fp32)
        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
    if (nloe) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);

        // _nloe variant only accumulates the first nloe elements
        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));

        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));

        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));

        // Zero out unused scales
        // nloe/8 bytes = (nloe / 32 sub-blocks) * 4 bytes per fp32 scale
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
        r0_dd = Q6_V_vand_QV(bmask, r0_dd);

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Reduce and convert into fp32
    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));

    hvx_vec_store_u(&s[0], 4, r0_sum);
}
569
+
570
+ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
571
+ float * restrict s,
572
+ const void * restrict vx,
573
+ uint32_t vx_row_size,
574
+ const void * restrict vy) {
575
+ assert(n % 32 == 0); // min sub-block size
576
+ assert((unsigned long) vx % 128 == 0);
577
+ assert((unsigned long) vy % 128 == 0);
578
+
579
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
580
+
581
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
582
+ const uint32_t x_qblk_size = qk; // int8
583
+ const uint32_t x_qrow_size = n; // int8 (not padded)
584
+
585
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
586
+ const uint32_t y_qblk_size = qk; // int8
587
+ const uint32_t y_qrow_size = n; // int8 (not padded)
588
+
589
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
590
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
591
+
592
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
593
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
594
+
595
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
596
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
597
+
598
+ // Row sum (qf32)
599
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
600
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
601
+
602
+ // Multiply and accumulate into int32.
603
+ // Compute combined scale (fp32).
604
+ // Apply scale to acc and accumulate into the row sum (qf32).
605
+
606
+ const uint32_t nb = n / qk; // num full blocks
607
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
608
+
609
+ uint32_t i = 0;
610
+ for (; i < nb; i++) {
611
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
612
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
613
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
614
+
615
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
616
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
617
+
618
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
619
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
620
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
621
+
622
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
623
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
624
+
625
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
626
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
627
+
628
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
629
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
630
+ }
631
+
632
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
633
+ if (nloe) {
634
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
635
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
636
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
637
+
638
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
639
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
640
+
641
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
642
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
643
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
644
+
645
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
646
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
647
+
648
+ // Zero out unused scales
649
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
650
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
651
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
652
+
653
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
654
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
655
+
656
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
657
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
658
+ }
659
+
660
+ // Convert into fp32 and reduce
661
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
662
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
663
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
664
+
665
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
666
+ }
667
+
668
// Dot product of one mxfp4x4x2 row (vx) against one q8x4x2 row (vy).
// x layout: fp4 quants (two per byte) first, then per-block e8m0 scales.
// y layout: int8 quants first, then per-block fp16 scales.
// The fp32 result is stored to s[0].
static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
                                     float * restrict s,
                                     const void * restrict vx,
                                     const void * restrict vy) {
    assert(n % 32 == 0); // min sub-block size
    assert((unsigned long) vx % 128 == 0);
    assert((unsigned long) vy % 128 == 0);

    const uint32_t qk = QK_MXFP4x4x2 * 4; // elements per full block

    const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
    const uint32_t x_qblk_size = qk / 2;    // fp4
    const uint32_t x_qrow_size = n / 2;     // fp4 (not padded)

    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
    const uint32_t y_qblk_size = qk;        // int8
    const uint32_t y_qrow_size = n;         // int8 (not padded)

    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);           // quants first
    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales

    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);           // quants first
    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales

    // Row sum (qf32)
    HVX_Vector r0_sum = Q6_V_vsplat_R(0);

    // Per iteration:
    //   Multiply and accumulate into int32.
    //   Compute combined scale (fp32).
    //   Apply scale to acc and accumulate into the row sum (qf32).

    const uint32_t nb = n / qk; // num full blocks
    int32_t nloe = n % qk;      // num leftover elements (must be signed)

    uint32_t i = 0;
    for (; i < nb; i++) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);

        // Per-block int32 dot products, converted to fp32
        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);

        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
        HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
        vy_d = Q6_Vsf_equals_Vqf32(vy_d);

        // Convert rX_d scales from e8m0 to fp32
        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
        // Left shift with zero fill to create FP32
        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
        r0_d = Q6_V_vdelta_VV(r0_d, expand);
        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
        r0_d = Q6_Vw_vasl_VwR(r0_d, 23); // e8m0 exponent into the fp32 exponent field

        // Combined per-block scale = x_scale * y_scale (fp32)
        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Process leftovers
    // NOTE(review): unlike the q8 paths this tail uses hvx_vec_rmpy_x8_full rather
    // than the _nloe variant; it appears to rely on the zeroed scales below to drop
    // the unused blocks -- confirm that is intended.
    if (nloe) {
        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);

        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));

        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);

        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
        HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
        vy_d = Q6_Vsf_equals_Vqf32(vy_d);

        // Convert rX_d scales from e8m0 to fp32
        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
        // Left shift with zero fill to create FP32
        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
        r0_d = Q6_V_vdelta_VV(r0_d, expand);
        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);

        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));

        // Zero-out unused scales
        // nloe/8 bytes = (nloe / 32 sub-blocks) * 4 bytes per fp32 scale
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
        r0_dd = Q6_V_vand_QV(bmask, r0_dd);

        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);

        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
    }

    // Reduce and convert into fp32
    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));

    hvx_vec_store_u(&s[0], 4, r0_sum);
}
775
+
776
+ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
777
+ float * restrict s,
778
+ const void * restrict vx,
779
+ uint32_t vx_row_size,
780
+ const void * restrict vy) {
781
+ assert(n % 32 == 0); // min sub-block size
782
+ assert((unsigned long) vx % 128 == 0);
783
+ assert((unsigned long) vy % 128 == 0);
784
+
785
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
786
+
787
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
788
+ const uint32_t x_qblk_size = qk / 2; // fp4
789
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
790
+
791
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
792
+ const uint32_t y_qblk_size = qk; // int8
793
+ const uint32_t y_qrow_size = n; // int8 (not padded)
794
+
795
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
796
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
797
+
798
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
799
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
800
+
801
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
802
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
803
+
804
+ // Row sum (qf32)
805
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
806
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
807
+
808
+ // Multiply and accumulate into int32.
809
+ // Compute combined scale (fp32).
810
+ // Apply scale to acc and accumulate into the row sum (qf32).
811
+
812
+ const uint32_t nb = n / qk; // num full blocks
813
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
814
+
815
+ uint32_t i = 0;
816
+ for (; i < nb; i++) {
817
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
818
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
819
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
820
+
821
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
822
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
823
+
824
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
825
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
826
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
827
+
828
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
829
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
830
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
831
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
832
+
833
+ // Convert rX_d scales from e8m0 to fp32
834
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
835
+ // Left shift with zero fill to create FP32
836
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
837
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
838
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
839
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
840
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
841
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
842
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
843
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
844
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
845
+
846
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
847
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
848
+
849
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
850
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
851
+
852
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
853
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
854
+ }
855
+
856
+ // Process leftovers
857
+ if (nloe) {
858
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
859
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
860
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
861
+
862
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
863
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
864
+
865
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
866
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
867
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
868
+
869
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
870
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
871
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
872
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
873
+
874
+ // Convert rX_d scales from e8m0 to fp32
875
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
876
+ // Left shift with zero fill to create FP32
877
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
878
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
879
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
880
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
881
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
882
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
883
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
884
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
885
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
886
+
887
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
888
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
889
+
890
+ // Zero-out unused scales
891
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
892
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
893
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
894
+
895
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
896
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
897
+
898
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
899
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
900
+ }
901
+
902
+ // Convert into fp32 and reduce
903
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
904
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
905
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
906
+
907
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
908
+ }
909
+
910
+ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
911
+ const HVX_Vector * restrict x = (const HVX_Vector *) vx;
912
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
913
+
914
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
915
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
916
+
917
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
918
+
919
+ uint32_t i = 0;
920
+
921
+ #pragma unroll(4)
922
+ for (i = 0; i < nvec; i++) {
923
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
924
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
925
+ }
926
+
927
+ if (nloe) {
928
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
929
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
930
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
931
+
932
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
933
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
934
+ }
935
+
936
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
937
+ hvx_vec_store_u(&s[0], 4, rsum);
938
+ }
939
+
940
+ static void vec_dot_f16_f16_aa_rx2(const int n,
941
+ float * restrict s,
942
+ const void * restrict vx,
943
+ uint32_t vx_row_size,
944
+ const void * restrict vy) {
945
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
946
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
947
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
948
+
949
+ uint32_t nvec = n / VLEN_FP16;
950
+ uint32_t nloe = n % VLEN_FP16;
951
+
952
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
953
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
954
+
955
+ uint32_t i = 0;
956
+
957
+ #pragma unroll(2)
958
+ for (i = 0; i < nvec; i++) {
959
+ HVX_Vector y_hf = y[i];
960
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
961
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
962
+
963
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
964
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
965
+ }
966
+
967
+ if (nloe) {
968
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
969
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
970
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
971
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
972
+
973
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
974
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
975
+
976
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
977
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
978
+ }
979
+
980
+ rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
981
+ rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
982
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
983
+
984
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
985
+ }
986
+
987
+ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
988
+ const HVX_UVector * restrict x = (const HVX_UVector *) vx;
989
+ const HVX_UVector * restrict y = (const HVX_UVector *) vy;
990
+
991
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
992
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
993
+
994
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
995
+
996
+ uint32_t i = 0;
997
+
998
+ #pragma unroll(4)
999
+ for (i = 0; i < nvec; i++) {
1000
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
1001
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1002
+ }
1003
+
1004
+ if (nloe) {
1005
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
1006
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
1007
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
1008
+
1009
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
1010
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
1011
+ }
1012
+
1013
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
1014
+ hvx_vec_store_u(&s[0], 4, rsum);
1015
+ }
1016
+
1017
// Dot product of an fp16 vector x against an fp32 vector y, both unaligned (uu).
// y is down-converted to fp16 per iteration so the fp16 multiply path can be used;
// the fp32 result is stored to s[0].
static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
    const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
    const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);

    HVX_Vector rsum = Q6_V_vsplat_R(0); // qf32 accumulator

    uint32_t i = 0;

    #pragma unroll(2)
    for (i = 0; i < nvec; i++) {
        // Load y (fp32) and convert into fp16.
        // vsub(v, 0) is used as an sf -> qf32 conversion; one fp16 vector worth of
        // elements spans two fp32 vectors, hence the paired loads.
        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
        // vdeal restores element order after the pairwise qf32 -> hf narrowing
        HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16 (same conversion as the main loop)
        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
        HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        // Zero-out unused elements
        // Note that we need to clear both x and y because they may contain NANs
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        x_hf = Q6_V_vand_QV(bmask, x_hf);
        y_hf = Q6_V_vand_QV(bmask, y_hf);

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
    hvx_vec_store_u(&s[0], 4, rsum);
}
1068
+
1069
// Expands the standard matmul locals from an ops-context pointer `octx`:
// tensor pointers (src0/src1/src2/dst), their VTCM scratchpads, and the
// per-dimension sizes (neXY) and byte strides (nbXY) following the usual
// ggml naming convention (ne0x = src0 dims, ne1x = src1, ne2x = src2,
// neX/nbX = dst).
#define htp_matmul_tensors_preamble \
    struct htp_tensor * restrict src0 = &octx->src0; \
    struct htp_tensor * restrict src1 = &octx->src1; \
    struct htp_tensor * restrict src2 = &octx->src2; \
    struct htp_tensor * restrict dst = &octx->dst; \
    struct htp_spad * restrict src0_spad = &octx->src0_spad; \
    struct htp_spad * restrict src1_spad = &octx->src1_spad; \
    struct htp_spad * restrict dst_spad = &octx->dst_spad; \
    \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
    \
    const uint32_t ne10 = src1->ne[0]; \
    const uint32_t ne11 = src1->ne[1]; \
    const uint32_t ne12 = src1->ne[2]; \
    const uint32_t ne13 = src1->ne[3]; \
    \
    const uint32_t ne20 = src2->ne[0]; \
    const uint32_t ne21 = src2->ne[1]; \
    const uint32_t ne22 = src2->ne[2]; \
    const uint32_t ne23 = src2->ne[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
    \
    const uint32_t nb10 = src1->nb[0]; \
    const uint32_t nb11 = src1->nb[1]; \
    const uint32_t nb12 = src1->nb[2]; \
    const uint32_t nb13 = src1->nb[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];
1112
+
1113
// Full matmul prologue: tensor dims/strides (htp_matmul_tensors_preamble) plus
// the calling thread's DMA queue (indexed by `ith`) and the per-thread src0
// row-partition size. Expects `octx`, `ith` in scope at the expansion site.
#define htp_matmul_preamble \
    htp_matmul_tensors_preamble; \
    dma_queue *dma_queue = octx->ctx->dma[ith]; \
    uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
1117
+
1118
+ // *** matmul with support for 4d tensors and full broadcasting
1119
+
1120
// General matmul over 4d tensors with full src0-into-src1 broadcasting.
// Work is split across `nth` threads (this call handles thread `ith`) along
// whichever result dimension is larger, then processed in 64x64 tiles with the
// type-specific mt->vec_dot kernel doing one output element per call.
static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
    htp_matmul_preamble;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    // broadcasting requires src1 dims to be multiples of src0 dims
    assert(ne12 % ne02 == 0);
    assert(ne13 % ne03 == 0);

    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
    const uint32_t nr0 = ne0;

    // This is the size of the rest of the dimensions of the result
    const uint32_t nr1 = ne1 * ne2 * ne3;

    // distribute the thread work across the inner or outer loop based on which one is larger
    uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows

    // The number of elements in each chunk
    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;

    uint32_t current_chunk = ith;

    // 2d chunk-grid coordinates of this thread's chunk
    const uint32_t ith0 = current_chunk % nchunk0;
    const uint32_t ith1 = current_chunk / nchunk0;

    const uint32_t ir0_start = dr0 * ith0;
    const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);

    const uint32_t ir1_start = dr1 * ith1;
    const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);

    // no work for this thread
    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
        return;
    }

    // block-tiling attempt
    const uint32_t blck_0 = 64;
    const uint32_t blck_1 = 64;

    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
                // decompose the flat ir1 index into (i13, i12, i11) using
                // precomputed fastdiv magic numbers (no hardware divide)
                const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);

                // broadcast src0 into src1
                const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
                const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);

                const uint32_t i1 = i11;
                const uint32_t i2 = i12;
                const uint32_t i3 = i13;

                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
                const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));

                // one vec_dot per output element in this tile row
                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
                    mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
                }
            }
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
1198
+
1199
// src1 tensor is already in VTCM spad
//
// 2D matmul: each thread owns a contiguous slice of src0 rows, streams them
// DDR -> VTCM through the DMA queue (two rows per descriptor, ring-buffered
// over MM_SPAD_SRC0_NROWS scratchpad slots) and runs the type-specific dot
// product against every src1 row already resident in VTCM.
static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
    htp_matmul_preamble;

    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
    const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows

    // Static row partitioning across threads
    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
    // End of the even-sized prefix: rows are consumed in pairs by the rx2
    // kernel; a possible odd trailing row is handled after the main loop.
    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    const size_t dst_row_size = nb1;
    const size_t src0_row_size = nb01;
    const size_t src1_row_size = nb11;

    const size_t src0_stride = src0_spad->stride;
    const size_t src1_stride = src1_spad->stride;

    // Per-thread VTCM scratchpads for all tensors
    // Note that the entire src1 tensor is already in VTCM
    // For other tensors we allocate N rows per thread, padded to HVX vector size
    uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
    uint8_t * restrict src1_data = src1_spad->data;

    volatile uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;

    // Prefill spad with src0 rows (up to MM_SPAD_SRC0_NROWS rows queued,
    // two rows per DMA descriptor)
    #pragma unroll(4)
    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
        const int is0 = (ir0 - src0_start_row);
        if (is0 >= MM_SPAD_SRC0_NROWS) {
            break;
        }
        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                                   src0_stride, src0_row_size, 2);
    }

    // Process src0 rows two at a time, overlapping compute with DMA refill
    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
        // pop completed transfer; ss0 points at the row pair in VTCM
        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;

        #pragma unroll(2)
        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
            float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
        }

        // Prefetch next (n + spad_nrows) row
        const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
        const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; // ring-buffer slot
        if (pr0 < src0_end_row_x2) {
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
                                       src0_stride, src0_row_size, 2);
        }
    }

    // Process the last row (if any) with the single-row vec_dot kernel
    if (src0_end_row != src0_end_row_x2) {
        uint32_t ir0 = src0_end_row_x2;
        const int is0 = (ir0 - src0_start_row);
        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                                   src0_stride, src0_row_size, 1);
        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;

        #pragma unroll(2)
        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
            float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
1288
+
1289
// q8x4x2 src1 tensor is already in VTCM spad
//
// Matrix-vector product: single src1 column, src0 rows split across threads.
// Results are accumulated into the per-thread dst scratchpad (`tmp`) and
// copied back to DDR in one pass at the end.
static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
    htp_matmul_preamble;

    const uint32_t src0_nrows = ne01;

    // Static row partitioning across threads; rows are processed in pairs,
    // with a possible odd trailing row handled after the main loop.
    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    const size_t dst_row_size = nb1;
    const size_t src0_row_size = nb01;
    const size_t src1_row_size = nb11;

    const size_t src0_stride = src0_spad->stride;
    const size_t src1_stride = src1_spad->stride;

    // Per-thread VTCM scratchpads for all tensors
    // Note that the entire src1 tensor is already in VTCM
    // For other tensors we allocate N rows per thread, padded to HVX vector size
    uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
    uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
    uint8_t * src1_data = src1_spad->data;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    // Per-thread fp32 output staging buffer in VTCM
    float * tmp = (float *) spad_dst;

    const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
    const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
    float * restrict dst_col = (float *) dst->data;

    // Prefill spad with 2x src0 rows
    #pragma unroll(2)
    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
        const uint32_t is0 = (ir0 - src0_start_row);
        if (is0 >= MM_SPAD_SRC0_NROWS) {
            break;
        }
        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                                   src0_stride, src0_row_size, 2);
    }

    // Process src0 rows two at a time, overlapping compute with DMA refill
    for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);

        // Prefetch next (n + spad_nrows) row
        const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
        const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; // ring-buffer slot
        if (pr0 < src0_end_row_x2) {
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
                                       src0_stride, src0_row_size, 2);
        }
    }

    // Process the last row (if any) with the single-row vec_dot kernel
    if (src0_end_row != src0_end_row_x2) {
        const uint32_t ir0 = src0_end_row_x2;
        const uint32_t is0 = (ir0 - src0_start_row);
        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                                   src0_stride, src0_row_size, 1);
        const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
        mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
    }

    // Copy this thread's slice of the result from VTCM staging to dst
    hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
         src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
1371
+
1372
// Index into the flat per-expert row-mapping table: entry (i1) of expert
// `row_id`; each expert's slab holds ids->ne[0] * ids->ne[1] entries.
// NOTE: expands in terms of local variables `matrix_rows` and `ids`.
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]

// One activated (expert slot, token) pair used by the MoE matmul to locate
// the corresponding src1/dst rows.
struct mmid_row_mapping {
    uint32_t i1;
    uint32_t i2;
};
1378
+
1379
// src1 tensor is already in VTCM spad
//
// MoE (mixture-of-experts) matmul: for each expert with at least one
// activated row, stream that expert's slice of src0 rows DDR -> VTCM via
// the DMA queue and multiply against the activated src1 rows recorded in
// the precomputed row-mapping table (held in the src2 scratchpad).
static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
    htp_matmul_preamble;

    struct htp_tensor * restrict ids = &octx->src2;         // expert id tensor
    struct htp_spad * restrict src2_spad = &octx->src2_spad; // row counts + row map

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    const uint32_t src0_nrows = ne01; // src0 rows per expert
    const uint32_t src1_nrows = ne11;

    // Static row partitioning across threads; rows processed in pairs,
    // with a possible odd trailing row handled after the main loop.
    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    const uint32_t n_ids = ids->ne[0]; // n_expert_used
    const uint32_t n_as = ne02; // n_expert

    // src2 spad layout: per-expert row counts first, then the row map
    const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
    const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);

    const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
    const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;

    const size_t dst_row_size = nb1;
    const size_t src0_row_size = nb01;
    const size_t src1_row_size = q8x4x2_row_size(ne10);

    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); // pad to HVX vector size

    // Per-thread VTCM scratchpads for all tensors
    // Note that the entire src1 tensor is already in VTCM
    // For other tensors we allocate N rows per thread, padded to HVX vector size
    uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
    uint8_t * restrict src1_data = src1_spad->data;

    for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
        const int32_t cne1 = matrix_row_counts[cur_a]; // activated rows for this expert

        if (cne1 == 0) {
            continue;
        }

        // Base of this expert's weight matrix within src0
        const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);

        // Prefill spad with src0 rows
        #pragma unroll(4)
        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
            const int is0 = (ir0 - src0_start_row);
            if (is0 >= MM_SPAD_SRC0_NROWS) {
                break;
            }
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                                       src0_row_size_padded, src0_row_size, 2);
        }

        // Process src0 rows two at a time, overlapping compute with DMA refill
        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;

            for (uint32_t cid = 0; cid < cne1; ++cid) {
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
                const int rm1 = row_mapping.i1; // expert idx
                const int rm2 = row_mapping.i2; // token idx

                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
                const uint8_t * restrict src1_col =
                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));

                mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
            }

            // Prefetch next (n + spad_nrows) row
            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; // ring-buffer slot
            if (pr0 < src0_end_row_x2) {
                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
                                           src0_row_size_padded, src0_row_size, 2);
            }
        }

        // Process the last row (if any) with the single-row vec_dot kernel
        if (src0_end_row != src0_end_row_x2) {
            uint32_t ir0 = src0_end_row_x2;
            const uint32_t is0 = (ir0 - src0_start_row);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                                       src0_row_size_padded, src0_row_size, 1);
            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;

            for (uint32_t cid = 0; cid < cne1; ++cid) {
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
                const int rm1 = row_mapping.i1; // expert idx
                const int rm2 = row_mapping.i2; // token idx

                const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
                const uint8_t * restrict src1_col =
                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));

                mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
            }
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
         src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
1499
+
1500
// src1 tensor is already in VTCM spad
//
// MoE matrix-vector product: for each activated expert id listed in src2,
// stream this thread's slice of that expert's src0 rows DDR -> VTCM and
// multiply them against the single src1 column resident in VTCM.
static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
    htp_matmul_preamble;

    struct htp_tensor * restrict ids = &octx->src2;          // expert id tensor
    struct htp_spad * restrict src2_spad = &octx->src2_spad;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    const uint32_t src0_nrows = ne01; // src0 rows per expert

    // Static row partitioning across threads; rows processed in pairs,
    // with a possible odd trailing row handled after the main loop.
    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    assert(ne13 % ne03 == 0);

    const size_t dst_row_size = nb1;
    const size_t src0_row_size = nb01;
    const size_t src1_row_size = q8x4x2_row_size(ne10);

    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); // pad to HVX vector size

    const uint32_t n_aids = src2->ne[0]; // num activated experts
    const uint32_t n_ids = ne02; // num experts

    // Per-thread VTCM scratchpads for all tensors
    // Note that the entire src1 tensor is already in VTCM
    // For other tensors we allocate N rows per thread, padded to HVX vector size
    uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
    uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
    uint8_t * restrict src1_data = src1_spad->data;

    for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert
        // Expert index comes from the src2 id tensor
        const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
        assert(eid < n_ids);

        const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
        const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
        float * restrict dst_row = (float *) (dst->data + ie1 * nb1);

        // Prefill spad with src0 rows
        #pragma unroll(4)
        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
            const int is0 = (ir0 - src0_start_row);
            if (is0 >= MM_SPAD_SRC0_NROWS) {
                break;
            }
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                                       src0_row_size_padded, src0_row_size, 2);
        }

        // Process src0 rows two at a time, overlapping compute with DMA refill
        for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);

            // Prefetch next (n + spad_nrows) row
            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; // ring-buffer slot
            if (pr0 < src0_end_row_x2) {
                dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
                                           src0_row_size_padded, src0_row_size, 2);
            }
        }

        // Process the last row (if any) with the single-row vec_dot kernel
        if (src0_end_row != src0_end_row_x2) {
            uint32_t ir0 = src0_end_row_x2;
            const uint32_t is0 = (ir0 - src0_start_row);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                                       src0_row_size_padded, src0_row_size, 1);
            const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
         src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
         dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
1590
+
1591
+ // *** dynamic quant
1592
+
1593
// Quantize 128 fp32 values (4 HVX vectors) to int8 with a separate fp16
// scale per group of 32 elements (FP32_QUANTIZE_GROUP_SIZE == 32 path).
// x and y_q must be 128-byte aligned; y_q receives 128 int8 quants and
// y_d receives 8 bytes of fp16 scales (one per 32-element group).
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;
    HVX_Vector zero = Q6_V_vsplat_R(0);

    // Use reduce max fp32 to find max(abs(e)) first
    HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
    HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
    HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
    HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
    // Load and convert into QF32 (vsub with zero performs the sf -> qf32 conversion)
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements

    // Convert to QF32
    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);

    // Combine and convert to fp16
    HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
    HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));

    // Convert into fp16
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Replicate first fp16 scale across all lanes
    HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
    vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
    vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);

    // scale d = max / 127 (0x2008 is fp16 bit pattern for 1.0/127.0)
    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
    HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);

    // Store the 4 group scales (2 bytes each) at y_d offsets 0,2,4,6;
    // vror extracts the scale held in the upper half of each vector
    hvx_vec_store_u(y_d + 0, 2, vd01_hf);
    HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
    hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);

    hvx_vec_store_u(y_d + 4, 2, vd23_hf);
    rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
    hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);

    // Divide input by the scale (multiply by reciprocal)
    HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
    HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
    vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
    vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));

    // Convert to int8 (round-to-nearest with saturation), pack 128 quants
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
1656
+
1657
// Quantize 128 fp32 values (4 HVX vectors) to int8 with fp16 scales per
// group of 64 elements (FP32_QUANTIZE_GROUP_SIZE == 64 path).  x and y_q
// must be 128-byte aligned; y_q receives 128 int8 quants and y_d receives
// 8 bytes of fp16 scales.
static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;

    // Load and convert into QF32 (vsub with zero performs the sf -> qf32 conversion)
    HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements

    // Convert into fp16
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Compute max and scale
    HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
    HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));

    // Replicate first fp16 scale across all lanes
    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
    vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
    vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);

    // scale d = max / 127 (0x2008 is fp16 bit pattern for 1.0/127.0)
    HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
    HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);

    // Store 4 bytes of scales from each half-block
    hvx_vec_store_u(y_d + 0, 4, vd01_hf);
    hvx_vec_store_u(y_d + 4, 4, vd23_hf);

    // Divide input by the scale (multiply by reciprocal)
    HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
    HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
    vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
    vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));

    // Convert to int8 (round-to-nearest with saturation), pack 128 quants
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
1704
+
1705
// Quantize 128 fp32 values (4 HVX vectors) to int8 with a single fp16
// scale shared by all 128 elements (FP32_QUANTIZE_GROUP_SIZE == 128 path).
// x and y_q must be 128-byte aligned; y_q receives 128 int8 quants.
// NOTE: the full scale vector (128 bytes, scale replicated across lanes)
// is stored unaligned at y_d; the caller stages scales in a temp buffer
// sized to tolerate this wide store.
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;

    // Load and convert into QF32 (vsub with zero performs the sf -> qf32 conversion)
    HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements

    // Convert into fp16
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Compute max over all 128 elements
    HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
    vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);

    // Replicate first fp16 scale across all lanes
    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
    vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);

    // scale d = max / 127 (0x2008 is fp16 bit pattern for 1.0/127.0)
    HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);

    *(HVX_UVector *) y_d = vd_hf;

    // Divide input by the scale (multiply by reciprocal)
    HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
    vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
    vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));

    // Convert to int8 (round-to-nearest with saturation), pack 128 quants
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
1747
+
1748
// Overrides input x
//
// Quantize one row of k fp32 values into q8x4x2 layout: all int8 quants
// first (k bytes), followed by the fp16 group scales (8 per block).
// x must be the aligned VTCM temp buffer — its storage is reused to stage
// the scales before they are copied to their final location.
static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
    assert(k % 32 == 0);
    const uint32_t qk = QK_Q8_0x4x2;
    const uint32_t nb = (k + qk - 1) / qk; // number of q8x4x2 super-blocks

    const uint32_t qrow_size = k; // int8

    const uint32_t dblk_size = 8 * 2; // 8x __fp16
    const uint32_t qblk_size = QK_Q8_0x4x2; // int8

    uint8_t * restrict y_q = (y + 0); // quants first
    uint8_t * restrict y_d = (y + qrow_size); // then scales

    // Temp scales override input since we're working off of the aligned temp buffer in VTCM
    uint8_t * restrict t_d = (uint8_t *) x;

    // Each super-block is quantized as two half-blocks; the group size
    // (and thus the per-group scale granularity) is a compile-time choice.
    for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
        quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
        quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
        quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
        quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
        quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
        quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
    }

    // now copy the scales into final location
    hvx_copy_fp16_ua(y_d, t_d, nb * 8);
}
1783
+
1784
+ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
1785
+ uint8_t * restrict dst,
1786
+ struct htp_spad * spad,
1787
+ uint32_t nth,
1788
+ uint32_t ith,
1789
+ uint32_t nrows_per_thread) {
1790
+
1791
+ uint64_t t1 = HAP_perf_get_qtimer_count();
1792
+
1793
+ const uint32_t ne0 = src->ne[0];
1794
+ const uint32_t ne1 = src->ne[1];
1795
+ const uint32_t ne2 = src->ne[2];
1796
+ const uint32_t ne3 = src->ne[3];
1797
+
1798
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
1799
+
1800
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
1801
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
1802
+
1803
+ const size_t src_row_size = src->nb[1];
1804
+ const size_t dst_row_size = q8x4x2_row_size(ne0);
1805
+
1806
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
1807
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
1808
+ uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
1809
+
1810
+ const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
1811
+ memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
1812
+
1813
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
1814
+ htp_l2fetch(src_data, 2, src_row_size, src_row_size);
1815
+ hvx_copy_fp32_aa(tmp_data, src_data, ne0);
1816
+
1817
+ // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
1818
+ quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
1819
+ dst_data += dst_row_size;
1820
+ src_data += src_row_size;
1821
+ }
1822
+
1823
+ uint64_t t2 = HAP_perf_get_qtimer_count();
1824
+
1825
+ FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
1826
+ ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1827
+ }
1828
+
1829
+ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
1830
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
1831
+
1832
+ uint64_t t1 = HAP_perf_get_qtimer_count();
1833
+
1834
+ const uint32_t ne0 = src->ne[0];
1835
+ const uint32_t ne1 = src->ne[1];
1836
+ const uint32_t ne2 = src->ne[2];
1837
+ const uint32_t ne3 = src->ne[3];
1838
+
1839
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
1840
+
1841
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
1842
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
1843
+
1844
+ const size_t src_row_size = ne0 * sizeof(float);
1845
+ const size_t src_stride = src->nb[1];
1846
+
1847
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
1848
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
1849
+
1850
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
1851
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
1852
+ hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
1853
+
1854
+ dst_data += dst_stride;
1855
+ src_data += src_stride;
1856
+ }
1857
+
1858
+ uint64_t t2 = HAP_perf_get_qtimer_count();
1859
+
1860
+ FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
1861
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1862
+ }
1863
+
1864
+ // TODO just a plain copy that should be done via the DMA during the Op setup
1865
+ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
1866
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
1867
+
1868
+ uint64_t t1 = HAP_perf_get_qtimer_count();
1869
+
1870
+ const uint32_t ne0 = src->ne[0];
1871
+ const uint32_t ne1 = src->ne[1];
1872
+ const uint32_t ne2 = src->ne[2];
1873
+ const uint32_t ne3 = src->ne[3];
1874
+
1875
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
1876
+
1877
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
1878
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
1879
+
1880
+ const size_t src_row_size = ne0 * sizeof(float);
1881
+ const size_t src_stride = src->nb[1];
1882
+
1883
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
1884
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
1885
+
1886
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
1887
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
1888
+ hvx_copy_fp16_au(dst_data, src_data, ne0);
1889
+
1890
+ dst_data += dst_stride;
1891
+ src_data += src_stride;
1892
+ }
1893
+
1894
+ uint64_t t2 = HAP_perf_get_qtimer_count();
1895
+
1896
+ FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
1897
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1898
+ }
1899
+
1900
+ static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
1901
+ struct htp_ops_context * octx = data;
1902
+ quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
1903
+ }
1904
+
1905
+ static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
1906
+ struct htp_ops_context * octx = data;
1907
+ quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
1908
+ }
1909
+
1910
+ static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
1911
+ struct htp_ops_context * octx = data;
1912
+ quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
1913
+ }
1914
+
1915
+ // ** matmul/matvec callbacks for worker_pool
1916
+
1917
+ static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1918
+ struct htp_ops_context * octx = data;
1919
+
1920
+ struct htp_matmul_type mt;
1921
+ mt.type = "q4x4x2-q8x4x2";
1922
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1923
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1924
+
1925
+ matvec_2d(&mt, octx, n, i);
1926
+ }
1927
+
1928
+ static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1929
+ struct htp_ops_context * octx = data;
1930
+
1931
+ struct htp_matmul_type mt;
1932
+ mt.type = "q4x4x2-q8x4x2";
1933
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1934
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1935
+
1936
+ matmul_2d(&mt, octx, n, i);
1937
+ }
1938
+
1939
+ static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1940
+ struct htp_ops_context * octx = data;
1941
+
1942
+ struct htp_matmul_type mt;
1943
+ mt.type = "q8x4x2-q8x4x2";
1944
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1945
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1946
+
1947
+ matvec_2d(&mt, octx, n, i);
1948
+ }
1949
+
1950
+ static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1951
+ struct htp_ops_context * octx = data;
1952
+
1953
+ struct htp_matmul_type mt;
1954
+ mt.type = "q8x4x2-q8x4x2";
1955
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1956
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1957
+
1958
+ matmul_2d(&mt, octx, n, i);
1959
+ }
1960
+
1961
+ static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1962
+ struct htp_ops_context * octx = data;
1963
+
1964
+ struct htp_matmul_type mt;
1965
+ mt.type = "mxfp4x4x2-q8x4x2";
1966
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1967
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1968
+
1969
+ matvec_2d(&mt, octx, n, i);
1970
+ }
1971
+
1972
+ static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1973
+ struct htp_ops_context * octx = data;
1974
+
1975
+ struct htp_matmul_type mt;
1976
+ mt.type = "mxfp4x4x2-q8x4x2";
1977
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1978
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1979
+
1980
+ matmul_2d(&mt, octx, n, i);
1981
+ }
1982
+
1983
+ static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
1984
+ struct htp_ops_context * octx = data;
1985
+
1986
+ struct htp_matmul_type mt;
1987
+ mt.type = "f16-f16";
1988
+ mt.vec_dot = vec_dot_f16_f16_aa;
1989
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
1990
+
1991
+ matvec_2d(&mt, octx, n, i);
1992
+ }
1993
+
1994
+ static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
1995
+ struct htp_ops_context * octx = data;
1996
+
1997
+ struct htp_matmul_type mt;
1998
+ mt.type = "f16-f16";
1999
+ mt.vec_dot = vec_dot_f16_f16_aa;
2000
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
2001
+
2002
+ matmul_2d(&mt, octx, n, i);
2003
+ }
2004
+
2005
+ static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
2006
+ struct htp_ops_context * octx = data;
2007
+
2008
+ struct htp_matmul_type mt;
2009
+ mt.type = "f16-f32";
2010
+ mt.vec_dot = vec_dot_f16_f32_uu;
2011
+
2012
+ matmul_4d(&mt, octx, n, i);
2013
+ }
2014
+
2015
+ static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
2016
+ struct htp_ops_context * octx = data;
2017
+
2018
+ struct htp_matmul_type mt;
2019
+ mt.type = "f16-f16";
2020
+ mt.vec_dot = vec_dot_f16_f16_uu;
2021
+
2022
+ matmul_4d(&mt, octx, n, i);
2023
+ }
2024
+
2025
+ // ** matmul-id callbacks for worker_pool
2026
+
2027
+ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2028
+ struct htp_ops_context * octx = data;
2029
+
2030
+ struct htp_matmul_type mt;
2031
+ mt.type = "q4x4x2-q8x4x2";
2032
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
2033
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
2034
+
2035
+ matvec_id(&mt, octx, n, i);
2036
+ }
2037
+
2038
+ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2039
+ struct htp_ops_context * octx = data;
2040
+
2041
+ struct htp_matmul_type mt;
2042
+ mt.type = "q4x4x2-q8x4x2";
2043
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
2044
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
2045
+
2046
+ matmul_id(&mt, octx, n, i);
2047
+ }
2048
+
2049
+ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2050
+ struct htp_ops_context * octx = data;
2051
+
2052
+ struct htp_matmul_type mt;
2053
+ mt.type = "q8x4x2-q8x4x2";
2054
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
2055
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
2056
+
2057
+ matvec_id(&mt, octx, n, i);
2058
+ }
2059
+
2060
+ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2061
+ struct htp_ops_context * octx = data;
2062
+
2063
+ struct htp_matmul_type mt;
2064
+ mt.type = "q8x4x2-q8x4x2";
2065
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
2066
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
2067
+
2068
+ matmul_id(&mt, octx, n, i);
2069
+ }
2070
+
2071
+ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2072
+ struct htp_ops_context * octx = data;
2073
+
2074
+ struct htp_matmul_type mt;
2075
+ mt.type = "mxfp4x4x2-q8x4x2";
2076
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
2077
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
2078
+
2079
+ matvec_id(&mt, octx, n, i);
2080
+ }
2081
+
2082
+ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
2083
+ struct htp_ops_context * octx = data;
2084
+
2085
+ struct htp_matmul_type mt;
2086
+ mt.type = "mxfp4x4x2-q8x4x2";
2087
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
2088
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
2089
+
2090
+ matmul_id(&mt, octx, n, i);
2091
+ }
2092
+
2093
+ // ** main matmul entry point
2094
+
2095
+ static inline bool htp_is_permuted(const struct htp_tensor * t) {
2096
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
2097
+ }
2098
+
2099
+ int op_matmul(struct htp_ops_context * octx) {
2100
+ htp_matmul_tensors_preamble;
2101
+
2102
+ const char * op_type;
2103
+
2104
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
2105
+ const uint32_t src1_nrows = ne11 * ne12 * ne13;
2106
+
2107
+ const size_t src0_row_size = nb01;
2108
+ const size_t dst_row_size = nb1;
2109
+ size_t src1_row_size = nb11;
2110
+
2111
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
2112
+ size_t src1_row_size_padded;
2113
+
2114
+ worker_callback_t quant_job_func;
2115
+ worker_callback_t matmul_job_func;
2116
+
2117
+ bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
2118
+
2119
+ switch (src0->type) {
2120
+ case HTP_TYPE_Q4_0:
2121
+ op_type = "q4x4x2-fp32";
2122
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2123
+ if (src1_nrows > 1) {
2124
+ matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
2125
+ } else {
2126
+ matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
2127
+ }
2128
+
2129
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2130
+
2131
+ // Entire src1 tensor is placed into the VTCM
2132
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2133
+
2134
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2135
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2136
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2137
+
2138
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2139
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2140
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2141
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2142
+ }
2143
+
2144
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2145
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2146
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2147
+ break;
2148
+
2149
+ case HTP_TYPE_Q8_0:
2150
+ op_type = "q8x4x2-fp32";
2151
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2152
+ if (src1_nrows > 1) {
2153
+ matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
2154
+ } else {
2155
+ matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
2156
+ }
2157
+
2158
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2159
+
2160
+ // Entire src1 tensor is placed into the VTCM
2161
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2162
+
2163
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2164
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2165
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2166
+
2167
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2168
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2169
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2170
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2171
+ }
2172
+
2173
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2174
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2175
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2176
+ break;
2177
+
2178
+ case HTP_TYPE_MXFP4:
2179
+ op_type = "mxfp4x4x2-f32";
2180
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2181
+ if (src1_nrows > 1) {
2182
+ matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
2183
+ } else {
2184
+ matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
2185
+ }
2186
+
2187
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2188
+
2189
+ // Entire src1 tensor is placed into the VTCM
2190
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2191
+
2192
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2193
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2194
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2195
+
2196
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2197
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2198
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2199
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2200
+ }
2201
+
2202
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2203
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2204
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2205
+ break;
2206
+
2207
+ case HTP_TYPE_F16:
2208
+ {
2209
+ // Try optimized f16-f16 path first (src1 in VTCM)
2210
+ const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128);
2211
+ const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
2212
+ const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
2213
+ const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
2214
+
2215
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
2216
+
2217
+ // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
2218
+ // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
2219
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
2220
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
2221
+
2222
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
2223
+ // Optimized path
2224
+ op_type = "f16-f16";
2225
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
2226
+ if (src1_nrows > 1) {
2227
+ matmul_job_func = htp_matmul_2d_f16_f16;
2228
+ } else {
2229
+ matmul_job_func = htp_matvec_2d_f16_f16;
2230
+ }
2231
+
2232
+ src1_row_size = f16_src1_row_size; // row size post quantization
2233
+
2234
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2235
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2236
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2237
+
2238
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2239
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2240
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2241
+ } else {
2242
+ // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
2243
+ quant_job_func = NULL;
2244
+ if (src1->type == HTP_TYPE_F32) {
2245
+ op_type = "f16-f32";
2246
+ matmul_job_func = htp_matmul_4d_f16_f32;
2247
+ } else {
2248
+ op_type = "f16-f16";
2249
+ matmul_job_func = htp_matmul_4d_f16_f16;
2250
+ }
2251
+
2252
+ src1_row_size = nb11; // original row size in DDR
2253
+
2254
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2255
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
2256
+ octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
2257
+
2258
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2259
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
2260
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2261
+
2262
+ // Init fastdiv for matmul_4d (supports broadcasting)
2263
+ octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
2264
+ octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
2265
+ octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
2266
+ octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
2267
+
2268
+ need_quant = false;
2269
+ }
2270
+ }
2271
+ break;
2272
+
2273
+ default:
2274
+ return HTP_STATUS_NO_SUPPORT;
2275
+ }
2276
+
2277
+ // VTCM scratchpads for all tensors
2278
+ size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2279
+
2280
+ FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
2281
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
2282
+
2283
+ FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
2284
+ src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
2285
+ dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
2286
+
2287
+ // Make sure the reserved vtcm size is sufficient
2288
+ if (octx->ctx->vtcm_size < spad_size) {
2289
+ FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
2290
+ octx->ctx->vtcm_size, spad_size);
2291
+ return HTP_STATUS_VTCM_TOO_SMALL;
2292
+ }
2293
+
2294
+ octx->src0_spad.data = octx->ctx->vtcm_base;
2295
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2296
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2297
+
2298
+ octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2299
+ octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
2300
+
2301
+ octx->src0_spad.stride = src0_row_size_padded;
2302
+ octx->src1_spad.stride = src1_row_size;
2303
+
2304
+ if (need_quant) {
2305
+ // Run quant jobs
2306
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2307
+ octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2308
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
2309
+ }
2310
+
2311
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2312
+ // Run matmul jobs
2313
+ const uint32_t n_matmul_jobs = octx->n_threads;
2314
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
2315
+ }
2316
+
2317
+ return HTP_STATUS_OK;
2318
+ }
2319
+
2320
+ // ** main matmul-id entry point
2321
+
2322
+ int op_matmul_id(struct htp_ops_context * octx) {
2323
+ htp_matmul_tensors_preamble;
2324
+
2325
+ struct htp_tensor * restrict ids = &octx->src2;
2326
+
2327
+ const char * op_type;
2328
+
2329
+ worker_callback_t quant_job_func;
2330
+ worker_callback_t matmul_id_job_func;
2331
+
2332
+ const size_t src0_row_size = nb01;
2333
+ const size_t dst_row_size = nb1;
2334
+
2335
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
2336
+
2337
+ const uint32_t src0_nrows = ne01; // per expert
2338
+ const uint32_t src1_nrows = ne11 * ne12 * ne13;
2339
+
2340
+ size_t src1_row_size;
2341
+ size_t src1_row_size_padded;
2342
+
2343
+ // row groups
2344
+ const int n_ids = ids->ne[0]; // n_expert_used
2345
+ const int n_as = ne02; // n_expert
2346
+
2347
+ size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
2348
+ size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
2349
+
2350
+ switch (src0->type) {
2351
+ case HTP_TYPE_Q4_0:
2352
+ op_type = "q4x2x2-f32";
2353
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2354
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2355
+ if (src1_nrows > 1) {
2356
+ matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
2357
+ } else {
2358
+ matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
2359
+ }
2360
+
2361
+ // Entire src1 tensor is placed into the VTCM
2362
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2363
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2364
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2365
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2366
+ octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2367
+
2368
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2369
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2370
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2371
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2372
+ }
2373
+
2374
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
2375
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2376
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2377
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2378
+ break;
2379
+
2380
+ case HTP_TYPE_Q8_0:
2381
+ op_type = "q8x2x2-f32";
2382
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2383
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2384
+ if (src1_nrows > 1) {
2385
+ matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
2386
+ } else {
2387
+ matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
2388
+ }
2389
+
2390
+ // Entire src1 tensor is placed into the VTCM
2391
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2392
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2393
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2394
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2395
+ octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2396
+
2397
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2398
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2399
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2400
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2401
+ }
2402
+
2403
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
2404
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2405
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2406
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2407
+ break;
2408
+
2409
+ case HTP_TYPE_MXFP4:
2410
+ op_type = "mxfp4x2x2-f32";
2411
+ quant_job_func = htp_quantize_fp32_q8x4x2;
2412
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
2413
+ if (src1_nrows > 1) {
2414
+ matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
2415
+ } else {
2416
+ matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
2417
+ }
2418
+
2419
+ // Entire src1 tensor is placed into the VTCM
2420
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
2421
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
2422
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
2423
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
2424
+ octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
2425
+
2426
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
2427
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
2428
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
2429
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
2430
+ }
2431
+
2432
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
2433
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
2434
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
2435
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
2436
+ break;
2437
+
2438
+ default:
2439
+ return HTP_STATUS_NO_SUPPORT;
2440
+ }
2441
+
2442
+ size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
2443
+
2444
+ FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
2445
+ octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
2446
+
2447
+ FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
2448
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
2449
+ ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
2450
+ src1->data, dst->data);
2451
+
2452
+ // Make sure the reserved vtcm size is sufficient
2453
+ if (octx->ctx->vtcm_size < spad_size) {
2454
+ FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
2455
+ octx->ctx->vtcm_size, spad_size);
2456
+ return HTP_STATUS_VTCM_TOO_SMALL;
2457
+ }
2458
+
2459
+ octx->src0_spad.data = octx->ctx->vtcm_base;
2460
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2461
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2462
+ octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
2463
+
2464
+ octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2465
+ octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
2466
+
2467
+ if (src1_nrows > 1) {
2468
+ // initialize matrix_row_counts and map
2469
+ uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
2470
+ struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;
2471
+
2472
+ memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
2473
+
2474
+ // group rows by src0 matrix
2475
+ for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
2476
+ for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
2477
+ const uint32_t i02 =
2478
+ *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
2479
+
2480
+ assert(i02 >= 0 && i02 < n_as);
2481
+
2482
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
2483
+ matrix_row_counts[i02] += 1;
2484
+ }
2485
+ }
2486
+ }
2487
+
2488
+ // Setup worker pool callbacks
2489
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
2490
+ // Run quant jobs
2491
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2492
+ octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2493
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
2494
+ }
2495
+
2496
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2497
+ // Run matmul-id jobs
2498
+ const uint32_t n_matmul_jobs = octx->n_threads;
2499
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
2500
+ }
2501
+
2502
+ return HTP_STATUS_OK;
2503
+ }