whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -2,166 +2,288 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
5
+ #include <assert.h>
8
6
  #include <HAP_farf.h>
9
- #include <HAP_mem.h>
10
7
  #include <HAP_perf.h>
11
- #include <hexagon_protos.h>
12
- #include <hexagon_types.h>
13
8
  #include <math.h>
14
9
  #include <string.h>
15
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-utils.h"
13
+ #include "hvx-dump.h"
14
+
16
15
  #define GGML_COMMON_DECL_C
17
16
  #include "ggml-common.h"
18
17
  #include "htp-ctx.h"
19
- #include "htp-dma.h"
20
18
  #include "htp-msg.h"
21
19
  #include "htp-ops.h"
22
- #include "hvx-utils.h"
23
- #include "ops-utils.h"
24
20
 
25
- // Dot product of FP32 and FP16 vectors, accumulating to float
26
- static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
27
- const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
21
+ // Must be multiple of 32
22
+ #define FLASH_ATTN_BLOCK_SIZE (32 * 2)
23
+
24
 + // This is a bit of a hack because the compiler is struggling to properly inline
25
+ // the default hvx_vec_f32_to_f16 with output into the local array.
26
+ static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
27
+ {
28
+ *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
29
+ }
30
+
31
+ // Dot product of two F16 vectors, accumulating to float
32
+ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
28
33
  const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
34
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
29
35
 
30
36
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
31
37
  uint32_t nloe = n % VLEN_FP16; // leftover elements
32
38
 
33
- const HVX_Vector zero = Q6_V_vsplat_R(0);
34
- HVX_Vector rsum = Q6_V_vsplat_R(0);
39
+ HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
35
40
 
36
41
  uint32_t i = 0;
37
42
 
38
43
  #pragma unroll(4)
39
44
  for (i = 0; i < nvec; i++) {
40
- // Load y (fp32) and convert into fp16
41
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
42
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
43
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
44
-
45
- // Load x (fp16)
46
- HVX_Vector x_hf = vx[i];
45
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
46
+ }
47
47
 
48
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
48
+ if (nloe) {
49
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
50
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
51
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
49
52
 
50
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
53
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
51
54
  }
52
55
 
53
- if (nloe) {
54
- // Load y (fp32) and convert into fp16
55
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
56
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
57
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
56
+ HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
57
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
58
+ hvx_vec_store_u(r, 4, rsum);
59
+ }
58
60
 
59
- // Load x (fp16)
60
- HVX_Vector x_hf = vx[i];
61
+ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
62
+ const uint8_t * restrict x,
63
+ const size_t stride_x,
64
+ const size_t nvec,
65
+ const size_t nloe) {
66
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
67
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
68
+ const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
69
+ const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
70
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
71
+
72
+ HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
73
+ HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
74
+ HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
75
+ HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
61
76
 
62
- // Zero-out unused elements
63
- // Note that we need to clear both x and y because they may contain NANs
64
- HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
65
- x_hf = Q6_V_vand_QV(bmask, x_hf);
66
- y_hf = Q6_V_vand_QV(bmask, y_hf);
77
+ uint32_t i = 0;
67
78
 
68
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
79
+ for (i = 0; i < nvec; i++) {
80
+ HVX_Vector y_hf = vy[i];
81
+ HVX_Vector x0_hf = vx0[i];
82
+ HVX_Vector x1_hf = vx1[i];
83
+ HVX_Vector x2_hf = vx2[i];
84
+ HVX_Vector x3_hf = vx3[i];
85
+
86
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
87
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
88
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
89
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
90
+ }
69
91
 
70
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
92
+ if (nloe) {
93
+ // Load x (fp16) and zero-out unused elements
94
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
95
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
96
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
97
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
98
+ HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
99
+ HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
100
+
101
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
102
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
103
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
104
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
71
105
  }
72
106
 
73
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
74
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
107
+ HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
108
+ HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
109
+ HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
110
+ HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
75
111
 
76
- hvx_vec_store_u(r, 4, rsum);
112
+ HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
113
+ return hvx_vec_reduce_sum_f32x4(rsum0123);
77
114
  }
78
115
 
79
- // Dot product of two F16 vectors, accumulating to float
80
- static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
81
- const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
82
- const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
116
+ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
117
+ const uint8_t * restrict x,
118
+ const size_t stride_x,
119
+ const size_t n,
120
+ float s) {
121
+
122
+ const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
123
+ const size_t nloe = n % VLEN_FP16; // leftover elements
124
+
125
+ HVX_Vector sums; // initialize at j = 0
126
+ const size_t stride_x_4 = stride_x * 4;
127
+ for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
128
+ HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
129
+ HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
130
+ sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
131
+ x += stride_x_4;
132
+ }
133
+
134
+ sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
135
+ return Q6_Vsf_equals_Vqf32(sums);
136
+ }
137
+
138
+ // MAD: y (F32) += x (F16) * s (F16)
139
+ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {
140
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
141
+
142
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
143
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
83
144
 
84
145
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
85
146
  uint32_t nloe = n % VLEN_FP16; // leftover elements
86
147
 
87
- const HVX_Vector zero = Q6_V_vsplat_R(0);
88
- HVX_Vector rsum = Q6_V_vsplat_R(0);
148
+ HVX_Vector S0 = hvx_vec_splat_f16(*s);
89
149
 
90
150
  uint32_t i = 0;
91
151
 
92
- #pragma unroll(4)
93
- for (i = 0; i < nvec; i++) {
94
- HVX_Vector y_hf = vy[i];
95
- HVX_Vector x_hf = vx[i];
96
-
97
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
98
-
99
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
152
+ #pragma unroll(2)
153
+ for (i = 0; i < nvec; ++i) {
154
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
100
155
  }
101
156
 
102
157
  if (nloe) {
103
- HVX_Vector y_hf = vy[i];
158
+ HVX_VectorPair xy_p = vy_p[i];
159
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
104
160
 
105
- // Load x (fp16) and zero-out unused elements
106
- HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
107
- HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
161
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
162
+ i = 2 * i; // index for vy
108
163
 
109
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
164
+ if (nloe >= VLEN_FP32) {
165
+ vy[i] = xy;
166
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
167
+ }
110
168
 
111
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
169
+ if (nloe) {
170
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
171
+ }
112
172
  }
113
-
114
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
115
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
116
- hvx_vec_store_u(r, 4, rsum);
117
173
  }
118
174
 
119
- // MAD: y (F32) += x (F16) * v (float)
120
- static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
121
- const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
122
- HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
175
+ // MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
176
+ static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
177
+ const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {
178
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
179
+ const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
123
180
 
124
- uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
125
- uint32_t nloe = n % VLEN_FP16; // leftover elements
181
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
182
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
126
183
 
127
- HVX_Vector S = hvx_vec_splat_fp16(s);
184
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
185
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
186
+
187
+ HVX_Vector S0 = hvx_vec_splat_f16(*s0);
188
+ HVX_Vector S1 = hvx_vec_splat_f16(*s1);
128
189
 
129
190
  uint32_t i = 0;
130
- #pragma unroll(4)
191
+
192
+ #pragma unroll(2)
131
193
  for (i = 0; i < nvec; ++i) {
132
- // Multiply x * s -> pair of F32 vectors
133
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
134
- ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
135
- ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
194
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
195
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
136
196
  }
137
197
 
138
198
  if (nloe) {
139
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
199
+ HVX_VectorPair xy_p = vy_p[i];
200
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
201
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
140
202
 
141
- HVX_Vector xs = Q6_V_lo_W(xs_p);
142
- i = 2 * i; // index for ptr_y
203
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
204
+ i = 2 * i; // index for vy
143
205
 
144
- if (nloe >= 32) {
145
- ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
146
- nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
206
+ if (nloe >= VLEN_FP32) {
207
+ vy[i] = xy;
208
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
147
209
  }
148
210
 
149
211
  if (nloe) {
150
- HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
151
- hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
212
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
152
213
  }
153
214
  }
154
215
  }
155
216
 
156
- #define FLASH_ATTN_BLOCK_SIZE 128
217
+ struct htp_fa_context {
218
+ const struct htp_ops_context * octx;
219
+
220
+ struct fastdiv_values src0_div21;
221
+ struct fastdiv_values src0_div1;
222
+
223
+ struct fastdiv_values broadcast_rk2;
224
+ struct fastdiv_values broadcast_rk3;
225
+ struct fastdiv_values broadcast_rv2;
226
+ struct fastdiv_values broadcast_rv3;
227
+
228
+ struct fastdiv_values src3_div2;
229
+ struct fastdiv_values src3_div3;
230
+
231
+ float scale;
232
+ float max_bias;
233
+ float logit_softcap;
234
+
235
+ uint32_t n_head_log2;
236
+ float m0;
237
+ float m1;
238
+
239
+ uint32_t n_blocks;
240
+
241
+ size_t size_q_row_padded;
242
+ size_t size_k_row_padded;
243
+ size_t size_v_row_padded;
244
+
245
+ size_t size_k_block;
246
+ size_t size_v_block;
247
+ size_t size_m_block;
157
248
 
158
- static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
249
+ uint32_t qrows;
250
+ uint32_t qrows_per_thread;
251
+
252
+ bool is_q_fp32;
253
+
254
+ uint64_t t_start;
255
+ };
256
+
257
+ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
258
+ assert((size_t) dst % 128 == 0);
259
+ assert((size_t) src % 128 == 0);
260
+
261
+ const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
262
+ HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
263
+
264
+ const uint32_t nvec = n / VLEN_FP32;
265
+ const uint32_t nloe = n % VLEN_FP32;
266
+
267
+ uint32_t i = 0;
268
+ #pragma unroll(4)
269
+ for (; i < nvec; ++i) {
270
+ vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
271
+ }
272
+ if (nloe) {
273
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
274
+ hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
275
+ }
276
+ }
277
+
278
+ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
279
+ struct htp_fa_context * factx = (struct htp_fa_context *) data;
280
+ const struct htp_ops_context * octx = factx->octx;
159
281
  const struct htp_tensor * q = &octx->src0;
160
282
  const struct htp_tensor * k = &octx->src1;
161
283
  const struct htp_tensor * v = &octx->src2;
162
284
  const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
163
285
  const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
164
- struct htp_tensor * dst = &octx->dst;
286
+ const struct htp_tensor * dst = &octx->dst;
165
287
 
166
288
  const uint32_t neq0 = q->ne[0];
167
289
  const uint32_t neq1 = q->ne[1];
@@ -198,22 +320,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
198
320
  const uint32_t nb2 = dst->nb[2];
199
321
  const uint32_t nb3 = dst->nb[3];
200
322
 
201
- float scale = 1.0f;
202
- float max_bias = 0.0f;
203
- float logit_softcap = 0.0f;
204
-
205
- memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
206
- memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
207
- memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
208
-
209
- if (logit_softcap != 0) {
210
- scale /= logit_softcap;
211
- }
212
-
213
323
  // total rows in q
214
- const uint32_t nr = neq1*neq2*neq3;
215
-
216
- const uint32_t dr = (nr + nth - 1) / nth;
324
+ const uint32_t nr = factx->qrows;
325
+ const uint32_t dr = factx->qrows_per_thread;
217
326
  const uint32_t ir0 = dr * ith;
218
327
  const uint32_t ir1 = MIN(ir0 + dr, nr);
219
328
 
@@ -225,18 +334,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
225
334
  const uint32_t DV = nev0;
226
335
 
227
336
  const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
228
- const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
229
-
230
337
  const size_t size_k_row = DK * sizeof(__fp16);
231
338
  const size_t size_v_row = DV * sizeof(__fp16);
232
- const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
233
-
234
- const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
235
- const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
236
-
237
- const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
238
- const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
239
- const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
240
339
 
241
340
  // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
242
341
  uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -245,72 +344,79 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
245
344
  uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
246
345
  uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
247
346
 
248
- const uint32_t n_head = neq2;
249
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
250
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
251
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
347
+ const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
252
348
 
253
349
  for (uint32_t ir = ir0; ir < ir1; ++ir) {
254
- const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
255
- const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
350
+ const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
351
+ const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
256
352
  const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
257
353
 
258
- const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
259
- const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
354
+ const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
355
+ const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
260
356
 
261
- const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
262
- const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
357
+ const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
358
+ const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
263
359
 
264
360
  // Fetch Q row
265
361
  const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
266
- dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
267
-
268
- const uint32_t h = iq2; // head index
269
- const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
270
-
271
- float S = 0.0f; // sum
272
- float M = -INFINITY; // maximum KQ value
362
+ dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
273
363
 
274
- // Clear accumulator
275
- float * VKQ32 = (float *) spad_a;
276
- memset(VKQ32, 0, DV * sizeof(float));
364
+ // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,
365
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
277
366
 
278
367
  const __fp16 * mp_base = NULL;
279
368
  if (mask) {
280
- const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
281
- const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
369
+ const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
370
+ const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
282
371
  mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
283
372
  }
284
373
 
285
- const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
286
-
287
374
  // Prefetch first two blocks
288
- for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
375
+ for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
289
376
  const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
290
377
  const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
291
378
 
292
379
  // K
293
380
  const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
294
- uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
295
- dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
381
+ uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
382
+ dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
296
383
 
297
384
  // V
298
385
  const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
299
- uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
300
- dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
386
+ uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
387
+ dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
301
388
 
302
389
  // Mask
303
390
  if (mask) {
304
391
  const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
305
- uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
392
+ uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
306
393
  // Mask is 1D contiguous for this row
307
394
  dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
308
395
  }
396
+
397
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
398
+ // ith, ir, ib, iq1, iq2, iq3,
399
+ // size_k_row, size_v_row, current_block_size,
400
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
309
401
  }
310
402
 
311
- const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
403
+ const uint32_t h = iq2; // head index
404
+ const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
405
+
406
+ HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
407
+ HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
312
408
 
313
- for (uint32_t ib = 0; ib < n_blocks; ++ib) {
409
+ // Clear accumulator
410
+ hvx_splat_f32_a(spad_a, 0, DV);
411
+ float * VKQ32 = (float *) (spad_a + 0);
412
+
413
+ uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
414
+ if (factx->is_q_fp32) {
415
+ hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
416
+ }
417
+
418
+ const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
419
+ for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
314
420
  const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
315
421
  const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
316
422
 
@@ -319,156 +425,166 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
319
425
  uint8_t * v_base = dma_queue_pop(dma).dst; // V
320
426
  __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
321
427
 
428
+ // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u",
429
+ // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,
430
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
431
+
322
432
  // Inner loop processing the block from VTCM
323
433
  uint32_t ic = 0;
324
434
 
325
- // Process in blocks of 32 (VLEN_FP32)
326
- for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
435
+ // Process in sub-blocks of 32 (VLEN_FP32)
436
+ HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
437
+ HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
438
+ for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
327
439
  // 1. Compute scores
328
- float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
329
- for (int j = 0; j < VLEN_FP32; ++j) {
330
- const uint32_t cur_ic = ic + j;
331
- const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
332
- if (q->type == HTP_TYPE_F32) {
333
- hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
334
- } else {
335
- hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
336
- }
337
- }
338
-
339
- HVX_Vector scores = *(HVX_Vector *) scores_arr;
440
+ HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
340
441
 
341
442
  // 2. Softcap
342
- if (logit_softcap != 0.0f) {
343
- scores = hvx_vec_tanh_fp32(scores);
344
- scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
443
+ if (factx->logit_softcap != 0.0f) {
444
+ scores = hvx_vec_tanh_f32(scores);
445
+ scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
345
446
  scores = Q6_Vsf_equals_Vqf32(scores);
346
447
  }
347
448
 
348
449
  // 3. Mask
349
450
  if (mask) {
350
451
  const __fp16 * mp = m_base + ic;
351
- HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
352
-
353
- HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
354
- HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
355
-
356
- HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
357
-
358
- HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
359
- HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
360
- scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
452
+ HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
453
+ HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
454
+ HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
455
+ scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
361
456
  scores = Q6_Vsf_equals_Vqf32(scores);
362
457
  }
363
458
 
364
- // 4. Online Softmax Update
365
- HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
366
- float m_block = hvx_vec_get_fp32(v_max);
367
-
368
- float M_old = M;
369
- float M_new = (m_block > M) ? m_block : M;
370
- M = M_new;
459
+ sb_scores[iv] = scores;
460
+ v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
461
+ }
371
462
 
372
- float ms = expf(M_old - M_new);
463
+ {
464
+ // 4. Online Softmax Update
465
+ HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
466
+ HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
467
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
468
+ M_vec = M_new_vec;
373
469
 
374
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
375
- S = S * ms;
470
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
376
471
 
377
- HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
378
- HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
379
- HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
472
+ HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
473
+ for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
474
+ HVX_Vector scores = sb_scores[iv];
475
+ HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
476
+ HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
380
477
 
381
- HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
382
- float p_sum = hvx_vec_get_fp32(p_sum_vec);
383
- S += p_sum;
478
+ p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
384
479
 
385
- // 5. Accumulate V
386
- float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
387
- *(HVX_Vector*)p_arr = P;
480
+ // 5. Accumulate V
481
+ __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
482
+ hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
388
483
 
389
- for (int j = 0; j < VLEN_FP32; ++j) {
390
- const uint32_t cur_ic = ic + j;
391
- const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
392
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
484
+ for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
485
+ const uint32_t cur_ic = ic2 + j;
486
+ const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
487
+ hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
488
+ }
393
489
  }
394
- }
395
490
 
396
- // Leftover
397
- for (; ic < current_block_size; ++ic) {
398
- float s_val;
399
- const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
491
+ p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
492
+ S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
493
+ }
400
494
 
401
- if (q->type == HTP_TYPE_F32) {
402
- hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
403
- } else {
404
- hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
405
- }
495
+ if (ic < current_block_size) {
496
+ // Sync scalars for leftover/next block if needed
497
+ float M = hvx_vec_get_f32(M_vec);
498
+ float S = hvx_vec_get_f32(S_vec);
499
+
500
+ // Leftover
501
+ for (; ic < current_block_size; ++ic) {
502
+ float s_val;
503
+ const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
504
+ hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
505
+ if (factx->logit_softcap != 0.0f) {
506
+ s_val = factx->logit_softcap * tanhf(s_val);
507
+ }
406
508
 
407
- if (logit_softcap != 0.0f) {
408
- s_val = logit_softcap * tanhf(s_val);
409
- }
509
+ if (mask) {
510
+ const float m_val = m_base[ic];
511
+ s_val += slope * m_val;
512
+ }
410
513
 
411
- if (mask) {
412
- const float m_val = m_base[ic];
413
- s_val += slope * m_val;
414
- }
514
+ const float Mold = M;
515
+ __fp16 vs = 1.0f;
415
516
 
416
- const float Mold = M;
417
- float ms = 1.0f;
418
- float vs = 1.0f;
517
+ if (s_val > M) {
518
+ M = s_val;
519
+ HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
520
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
521
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
419
522
 
420
- if (s_val > M) {
421
- M = s_val;
422
- ms = expf(Mold - M);
423
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
424
- } else {
425
- vs = expf(s_val - M);
426
- }
523
+ float ms = hvx_vec_get_f32(ms_vec);
524
+ S = S * ms + vs;
525
+ } else {
526
+ HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
527
+ vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
528
+ S += vs;
529
+ }
427
530
 
428
- const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
531
+ const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
429
532
 
430
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
533
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
534
+ }
431
535
 
432
- S = S * ms + vs;
536
+ M_vec = hvx_vec_splat_f32(M);
537
+ S_vec = hvx_vec_splat_f32(S);
433
538
  }
434
539
 
435
540
  // Issue DMA for next+1 block (if exists)
436
- if (ib + 2 < n_blocks) {
541
+ if (ib + 2 < factx->n_blocks) {
437
542
  const uint32_t next_ib = ib + 2;
438
543
  const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
439
544
  const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
440
545
 
441
546
  // K
442
547
  const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
443
- dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
548
+ dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
444
549
 
445
550
  // V
446
551
  const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
447
- dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
552
+ dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
448
553
 
449
554
  // Mask
450
555
  if (mask) {
451
556
  const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
452
557
  dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
453
558
  }
559
+
560
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
561
+ // ith, ir, next_ib, iq1, iq2, iq3,
562
+ // size_k_row, size_v_row, next_block_size,
563
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
454
564
  }
455
565
  }
456
566
 
457
567
  // sinks
568
+ float M = hvx_vec_get_f32(M_vec);
569
+ float S = hvx_vec_get_f32(S_vec);
570
+
458
571
  if (sinks) {
459
572
  const float s = ((float *)((char *) sinks->data))[h];
460
573
 
461
- float ms = 1.0f;
462
574
  float vs = 1.0f;
463
575
 
464
576
  if (s > M) {
465
- ms = expf(M - s);
466
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
577
+ HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
578
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
579
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
580
+
581
+ float ms = hvx_vec_get_f32(ms_vec);
582
+ S = S * ms + vs;
467
583
  } else {
468
- vs = expf(s - M);
584
+ HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
585
+ vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
586
+ S += vs;
469
587
  }
470
-
471
- S = S * ms + vs;
472
588
  }
473
589
 
474
590
  const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -484,60 +600,91 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
484
600
  uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
485
601
 
486
602
  if (dst->type == HTP_TYPE_F32) {
487
- hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
603
+ hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
488
604
  } else if (dst->type == HTP_TYPE_F16) {
489
- hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
605
+ hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
490
606
  }
491
607
  }
492
608
  }
493
609
 
494
- static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
495
- struct htp_ops_context * octx = data;
496
- flash_attn_ext_f16_thread(octx, i, n);
497
- }
498
-
499
610
  int op_flash_attn_ext(struct htp_ops_context * octx) {
500
611
  const struct htp_tensor * q = &octx->src0;
501
612
  const struct htp_tensor * k = &octx->src1;
502
613
  const struct htp_tensor * v = &octx->src2;
503
- const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
504
- struct htp_tensor * dst = &octx->dst;
614
+ const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
615
+ const struct htp_tensor * dst = &octx->dst;
505
616
 
506
617
  // Check support
507
- if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
508
- k->type != HTP_TYPE_F16 ||
509
- v->type != HTP_TYPE_F16) {
618
+ if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
510
619
  return HTP_STATUS_NO_SUPPORT;
511
620
  }
512
621
 
513
- octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
514
- octx->src0_div1 = init_fastdiv_values(q->ne[1]);
622
+ struct htp_fa_context factx;
623
+ factx.octx = octx;
515
624
 
516
- octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
517
- octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
518
- octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
519
- octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
625
+ factx.t_start = HAP_perf_get_qtimer_count();
626
+
627
+ factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
628
+ factx.src0_div1 = init_fastdiv_values(q->ne[1]);
629
+
630
+ factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
631
+ factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
632
+ factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
633
+ factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
520
634
 
521
635
  if (mask) {
522
- octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
523
- octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
636
+ factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
637
+ factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
638
+ }
639
+
640
+ factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
641
+ factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
642
+ factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
643
+ factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
644
+
645
+ size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
646
+ factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
647
+ factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
648
+ factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
649
+
650
+ factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
651
+
652
+ float scale = 1.0f;
653
+ float max_bias = 0.0f;
654
+ float logit_softcap = 0.0f;
655
+
656
+ memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
657
+ memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
658
+ memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
659
+
660
+ if (logit_softcap != 0.0f) {
661
+ scale /= logit_softcap;
524
662
  }
525
663
 
526
- size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
527
- size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
528
- size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
664
+ factx.scale = scale;
665
+ factx.max_bias = max_bias;
666
+ factx.logit_softcap = logit_softcap;
667
+
668
+ uint32_t n_head = q->ne[2];
669
+ factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
670
+ factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
671
+ factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
672
+
673
+ // total rows in q
674
+ const uint32_t neq0 = q->ne[0];
675
+ const uint32_t neq1 = q->ne[1];
676
+ const uint32_t neq2 = q->ne[2];
677
+ const uint32_t neq3 = q->ne[3];
529
678
 
530
- size_t size_q_block = size_q_row_padded * 1; // single row for now
531
- size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
532
- size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
533
- size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
679
+ factx.qrows = neq1*neq2*neq3;
680
+ factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;
534
681
 
535
- size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
682
+ size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
536
683
 
537
684
  octx->src0_spad.size_per_thread = size_q_block * 1;
538
- octx->src1_spad.size_per_thread = size_k_block * 2;
539
- octx->src2_spad.size_per_thread = size_v_block * 2;
540
- octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
685
+ octx->src1_spad.size_per_thread = factx.size_k_block * 2;
686
+ octx->src2_spad.size_per_thread = factx.size_v_block * 2;
687
+ octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
541
688
  octx->dst_spad.size_per_thread = size_vkq_acc;
542
689
 
543
690
  octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -559,7 +706,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
559
706
  octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
560
707
 
561
708
  if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
562
- worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
709
+ worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
563
710
  }
564
711
 
565
712
  return HTP_STATUS_OK;