whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only, and reflects the changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -2,41 +2,52 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
8
-
9
5
  #include <HAP_farf.h>
10
- #include <HAP_mem.h>
11
6
  #include <HAP_perf.h>
12
- #include <HAP_ps.h>
13
- #include <hexagon_protos.h>
14
- #include <hexagon_types.h>
7
+
15
8
  #include <math.h>
16
- #include <qurt_thread.h>
17
9
  #include <string.h>
18
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-utils.h"
13
+
19
14
  #define GGML_COMMON_DECL_C
20
15
  #include "ggml-common.h"
21
16
  #include "htp-ctx.h"
22
- #include "htp-dma.h"
23
17
  #include "htp-msg.h"
24
18
  #include "htp-ops.h"
25
- #include "hvx-utils.h"
26
- #include "ops-utils.h"
27
19
 
28
- typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
29
- const uint8_t * src1,
30
- uint8_t * data_dst,
31
- const int num_elems);
20
+ #ifndef MIN
21
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
22
+ #endif
32
23
 
33
- static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
34
- static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
24
+ // Context for binary operations
25
+ struct htp_binary_context {
26
+ struct htp_ops_context * octx;
27
+ struct fastdiv_values dim1_div;
28
+ struct fastdiv_values dim2_div;
29
+ struct fastdiv_values dim12_div;
30
+
31
+ struct fastdiv_values src1_dim1_div; // ne11
32
+ struct fastdiv_values src1_dim2_div; // ne12
33
+ struct fastdiv_values src1_dim3_div; // ne13
34
+
35
+ uint32_t nrows_per_thread;
36
+ bool split_at_ne01;
37
+ bool split_at_ne02;
38
+
39
+ // Precomputed values
40
+ uint32_t block_max;
41
+ size_t src0_row_size_aligned;
42
+ size_t src1_row_size_aligned;
43
+ size_t dst_row_size_aligned;
44
+ uint32_t src1_fetch_rows; // 1 or block_max
45
+ uint32_t src1_dma_stride; // 0 or stride
46
+ };
35
47
 
36
48
  #define htp_binary_preamble \
37
49
  const struct htp_tensor * src0 = &octx->src0; \
38
50
  const struct htp_tensor * src1 = &octx->src1; \
39
- const struct htp_tensor * src2 = &octx->src2; \
40
51
  struct htp_tensor * dst = &octx->dst; \
41
52
  \
42
53
  const uint32_t ne00 = src0->ne[0]; \
@@ -49,272 +60,752 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
49
60
  const uint32_t ne12 = src1->ne[2]; \
50
61
  const uint32_t ne13 = src1->ne[3]; \
51
62
  \
52
- const uint32_t ne0 = dst->ne[0]; \
53
- const uint32_t ne1 = dst->ne[1]; \
54
- const uint32_t ne2 = dst->ne[2]; \
55
- const uint32_t ne3 = dst->ne[3]; \
56
- \
57
- const uint32_t nb00 = src0->nb[0]; \
58
63
  const uint32_t nb01 = src0->nb[1]; \
59
64
  const uint32_t nb02 = src0->nb[2]; \
60
65
  const uint32_t nb03 = src0->nb[3]; \
61
66
  \
62
- const uint32_t nb10 = src1->nb[0]; \
63
67
  const uint32_t nb11 = src1->nb[1]; \
64
68
  const uint32_t nb12 = src1->nb[2]; \
65
69
  const uint32_t nb13 = src1->nb[3]; \
66
70
  \
67
- const uint32_t nb0 = dst->nb[0]; \
68
71
  const uint32_t nb1 = dst->nb[1]; \
69
72
  const uint32_t nb2 = dst->nb[2]; \
70
- const uint32_t nb3 = dst->nb[3]; \
71
- \
72
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
73
+ const uint32_t nb3 = dst->nb[3];
73
74
 
74
- static void binary_job_f32_per_thread(struct htp_ops_context * octx,
75
- uint8_t * spad_data,
76
- uint32_t nth,
77
- uint32_t ith,
78
- enum htp_op op) {
79
- htp_binary_preamble;
75
+ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
76
+ uint32_t ne01, uint32_t ne02) {
77
+ uint32_t i03, i02, i01, rem;
78
+ i03 = fastdiv(ir, &bctx->dim12_div);
79
+ rem = ir - i03 * (ne02 * ne01);
80
+ i02 = fastdiv(rem, &bctx->dim1_div);
81
+ i01 = rem - i02 * ne01;
80
82
 
81
- const size_t src0_row_size = nb01;
82
- const size_t src1_row_size = nb11;
83
- const size_t dst_row_size = nb1;
83
+ uint32_t rows_left = end_row - ir;
84
+ uint32_t block_limit = rows_left;
84
85
 
85
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
86
- const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
86
+ if (bctx->split_at_ne01) {
87
+ block_limit = MIN(block_limit, ne01 - i01);
88
+ }
89
+ if (bctx->split_at_ne02) {
90
+ uint32_t rows_in_plane = (ne02 * ne01) - rem;
91
+ block_limit = MIN(block_limit, rows_in_plane);
92
+ }
87
93
 
88
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
89
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
94
+ return MIN(bctx->block_max, block_limit);
95
+ }
90
96
 
91
- // no work for this thread
92
- if (src0_start_row >= src0_end_row) {
93
- return;
97
+ // Macro for scalar op switch
98
+ #define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
99
+ if(TYPE == HTP_TYPE_F32) { \
100
+ switch (octx->op) { \
101
+ case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
102
+ case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
103
+ case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
104
+ case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
105
+ default: break; \
106
+ } \
107
+ } \
108
+ else { \
109
+ switch (octx->op) { \
110
+ case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
111
+ case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
112
+ case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
113
+ case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
114
+ default: break; \
115
+ } \
94
116
  }
95
117
 
96
- uint64_t t1, t2;
97
- t1 = HAP_perf_get_qtimer_count();
98
-
99
- int is_aligned = 1;
100
- int opt_path = 0;
101
- if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
102
- (0 == htp_is_aligned((void *) dst->data, VLEN))) {
103
- FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
104
- is_aligned = 0;
118
+ // Macro for vector op switch (All Aligned)
119
+ #define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
120
+ if(TYPE == HTP_TYPE_F32) { \
121
+ switch (octx->op) { \
122
+ case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
123
+ case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
124
+ case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
125
+ case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
126
+ default: break; \
127
+ } \
128
+ } \
129
+ else { \
130
+ switch (octx->op) { \
131
+ case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
132
+ case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
133
+ case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
134
+ case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
135
+ default: break; \
136
+ } \
105
137
  }
106
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
107
- opt_path = 1;
138
+
139
+ // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
140
+ #define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
141
+ if(TYPE == HTP_TYPE_F32) { \
142
+ switch (octx->op) { \
143
+ case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
144
+ case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
145
+ case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
146
+ case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
147
+ default: break; \
148
+ } \
149
+ } \
150
+ else { \
151
+ switch (octx->op) { \
152
+ case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
153
+ case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
154
+ case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
155
+ case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
156
+ default: break; \
157
+ } \
108
158
  }
109
159
 
110
- hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
160
+ // Macro for vector op switch (All Unaligned - generic loop used in element repeat)
161
+ #define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
162
+ if(TYPE == HTP_TYPE_F32) { \
163
+ switch (octx->op) { \
164
+ case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
165
+ case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
166
+ case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
167
+ case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
168
+ default: break; \
169
+ } \
170
+ } \
171
+ else { \
172
+ switch (octx->op) { \
173
+ case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
174
+ case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
175
+ case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
176
+ case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
177
+ default: break; \
178
+ } \
179
+ }
111
180
 
112
- uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
181
+ // 1. Scalar src1 (ne10 == 1)
182
+ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
183
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
184
+ struct htp_ops_context * octx = bctx->octx;
185
+ htp_binary_preamble;
113
186
 
114
- const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
115
- uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
187
+ const uint32_t src0_type = octx->src0.type;
188
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
189
+ const uint32_t total_rows = ne01 * ne02 * ne03;
190
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
191
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
192
+ if (start_row >= end_row) return;
193
+
194
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
195
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
196
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
197
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
198
+
199
+ dma_queue * q = octx->ctx->dma[ith];
200
+ uint32_t ir_prefetch = start_row;
201
+ int spad_idx = 0;
202
+
203
+ // Preamble
204
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
205
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
206
+ uint32_t i03, i02, i01, rem;
207
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
208
+ rem = ir_prefetch - i03 * (ne02 * ne01);
209
+ i02 = fastdiv(rem, &bctx->dim1_div);
210
+ i01 = rem - i02 * ne01;
211
+
212
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
213
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
214
+
215
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
216
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
217
+
218
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
219
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
220
+ ir_prefetch += current_block_size;
221
+ spad_idx ^= 1;
222
+ }
116
223
 
117
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
224
+ // Main loop
225
+ for (uint32_t ir = start_row; ir < end_row; ) {
226
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
118
227
 
119
- const uint32_t ne02_ne01 = ne02 * ne01;
228
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
229
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
120
230
 
121
- for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
122
- const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
123
- const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
124
- const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
231
+ uint32_t i03, i02, i01, rem;
232
+ i03 = fastdiv(ir, &bctx->dim12_div);
233
+ rem = ir - i03 * (ne02 * ne01);
234
+ i02 = fastdiv(rem, &bctx->dim1_div);
235
+ i01 = rem - i02 * ne01;
125
236
 
126
- const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
127
- const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
128
- const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
237
+ // src1 indices (broadcast/repeat)
238
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
239
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
240
+ uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
129
241
 
130
- const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
242
+ uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
243
+ uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
131
244
 
132
- if (ir + 1 < src0_end_row) {
133
- htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
134
- if (src1_row_size == src0_row_size) {
135
- htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
136
- }
245
+ for (uint32_t r = 0; r < current_block_size; r++) {
246
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
247
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
248
+ COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
249
+ src1_ptr += s1_stride;
137
250
  }
138
251
 
139
- const uint32_t nr0 = ne00 / ne10;
140
- if (nr0 > 1) {
141
- if ((1 == is_aligned) && (nr0 == ne00)) {
142
- hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
143
- } else {
144
- for (uint32_t r = 0; r < nr0; r++) {
145
- memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
146
- }
147
- }
148
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
149
- } else {
150
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
252
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
253
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
254
+
255
+ if (ir_prefetch < end_row) {
256
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
257
+ uint32_t p03, p02, p01, prem;
258
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
259
+ prem = ir_prefetch - p03 * (ne02 * ne01);
260
+ p02 = fastdiv(prem, &bctx->dim1_div);
261
+ p01 = prem - p02 * ne01;
262
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
263
+
264
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
265
+ ir_prefetch += next_block_size;
151
266
  }
267
+ ir += current_block_size;
268
+ }
269
+ dma_queue_flush(q);
270
+ }
152
271
 
153
- src0_ptr += src0_row_size;
154
- dst_ptr += dst_row_size;
272
+ // 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
273
+ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
274
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
275
+ struct htp_ops_context * octx = bctx->octx;
276
+ htp_binary_preamble;
277
+
278
+ const uint32_t src0_type = octx->src0.type;
279
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
280
+ const uint32_t total_rows = ne01 * ne02 * ne03;
281
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
282
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
283
+ if (start_row >= end_row) return;
284
+
285
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
286
+ uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
287
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
288
+
289
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
290
+ size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
291
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
292
+
293
+ dma_queue * q = octx->ctx->dma[ith];
294
+ uint32_t ir_prefetch = start_row;
295
+ int spad_idx = 0;
296
+
297
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
298
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
299
+ uint32_t i03, i02, i01, rem;
300
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
301
+ rem = ir_prefetch - i03 * (ne02 * ne01);
302
+ i02 = fastdiv(rem, &bctx->dim1_div);
303
+ i01 = rem - i02 * ne01;
304
+
305
+ uint32_t i13 = (ne13 == 1) ? 0 : i03;
306
+ uint32_t i12 = (ne12 == 1) ? 0 : i02;
307
+ uint32_t i11 = (ne11 == 1) ? 0 : i01;
308
+
309
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
310
+ uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
311
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
312
+
313
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
314
+ uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
315
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
316
+
317
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
318
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
319
+ dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
320
+ ir_prefetch += current_block_size;
321
+ spad_idx ^= 1;
155
322
  }
156
323
 
157
- t2 = HAP_perf_get_qtimer_count();
324
+ for (uint32_t ir = start_row; ir < end_row; ) {
325
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
326
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
327
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
328
+ uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
329
+
330
+ for (uint32_t r = 0; r < current_block_size; r++) {
331
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
332
+ uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
333
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
334
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
335
+ }
158
336
 
159
- FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
160
- ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
161
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
337
+ uint32_t i03, i02, i01, rem;
338
+ i03 = fastdiv(ir, &bctx->dim12_div);
339
+ rem = ir - i03 * (ne02 * ne01);
340
+ i02 = fastdiv(rem, &bctx->dim1_div);
341
+ i01 = rem - i02 * ne01;
342
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
343
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
344
+
345
+ if (ir_prefetch < end_row) {
346
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
347
+ uint32_t p03, p02, p01, prem;
348
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
349
+ prem = ir_prefetch - p03 * (ne02 * ne01);
350
+ p02 = fastdiv(prem, &bctx->dim1_div);
351
+ p01 = prem - p02 * ne01;
352
+
353
+ uint32_t p13 = (ne13 == 1) ? 0 : p03;
354
+ uint32_t p12 = (ne12 == 1) ? 0 : p02;
355
+ uint32_t p11 = (ne11 == 1) ? 0 : p01;
356
+
357
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
358
+ uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
359
+
360
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
361
+ dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
362
+
363
+ ir_prefetch += next_block_size;
364
+ }
365
+ ir += current_block_size;
366
+ }
367
+ dma_queue_flush(q);
162
368
  }
163
369
 
164
- static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
165
- uint8_t * spad_data,
166
- uint32_t nth,
167
- uint32_t ith,
168
- hvx_elemwise_f32_func func_HVX) {
370
+ // 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)
371
+ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
372
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
373
+ struct htp_ops_context * octx = bctx->octx;
169
374
  htp_binary_preamble;
170
375
 
171
- const size_t src0_row_size = nb01;
172
- const size_t src1_row_size = nb11;
173
- const size_t dst_row_size = nb1;
376
+ const uint32_t src0_type = octx->src0.type;
377
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
378
+ const uint32_t total_rows = ne01 * ne02 * ne03;
379
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
380
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
381
+ if (start_row >= end_row) return;
174
382
 
175
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
383
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
384
+ uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
385
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
176
386
 
177
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
178
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
387
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
388
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
179
389
 
180
- // no work for this thread
181
- if (src0_start_row >= src0_end_row) {
182
- return;
183
- }
390
+ dma_queue * q = octx->ctx->dma[ith];
391
+ uint32_t ir_prefetch = start_row;
392
+ int spad_idx = 0;
393
+
394
+ void * s1_ptr = (void *) src1_spad;
395
+
396
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
397
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
398
+ uint32_t i03, i02, i01, rem;
399
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
400
+ rem = ir_prefetch - i03 * (ne02 * ne01);
401
+ i02 = fastdiv(rem, &bctx->dim1_div);
402
+ i01 = rem - i02 * ne01;
184
403
 
185
- uint64_t t1, t2;
186
- t1 = HAP_perf_get_qtimer_count();
404
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
405
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
187
406
 
188
- if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
189
- (0 == htp_is_aligned((void *) dst->data, VLEN))) {
190
- FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
407
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
408
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
409
+
410
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
411
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
412
+ ir_prefetch += current_block_size;
413
+ spad_idx ^= 1;
191
414
  }
192
415
 
193
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
194
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
195
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
196
-
197
- const uint32_t ne02_ne01 = ne02 * ne01;
198
- for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
199
- // src0 indices
200
- const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
201
- const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
202
- const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
203
-
204
- // src1 indices
205
- const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
206
- assert(i11 >= 0 && i11 < ne11);
207
-
208
- float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
209
- const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
210
- const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
211
-
212
- if (ir + 1 < src0_end_row) {
213
- htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
214
- if (src1_row_size == src0_row_size) {
215
- htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
216
- }
416
+ for (uint32_t ir = start_row; ir < end_row; ) {
417
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
418
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
419
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
420
+
421
+ for (uint32_t r = 0; r < current_block_size; r++) {
422
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
423
+ uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
424
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
425
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
217
426
  }
218
427
 
219
- const uint32_t nr0 = ne00 / ne10;
220
- if (nr0 > 1) {
221
- for (uint32_t r = 0; r < nr0; r++) {
222
- memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
223
- }
224
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
225
- } else {
226
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
428
+ uint32_t i03, i02, i01, rem;
429
+ i03 = fastdiv(ir, &bctx->dim12_div);
430
+ rem = ir - i03 * (ne02 * ne01);
431
+ i02 = fastdiv(rem, &bctx->dim1_div);
432
+ i01 = rem - i02 * ne01;
433
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
434
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
435
+
436
+ if (ir_prefetch < end_row) {
437
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
438
+ uint32_t p03, p02, p01, prem;
439
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
440
+ prem = ir_prefetch - p03 * (ne02 * ne01);
441
+ p02 = fastdiv(prem, &bctx->dim1_div);
442
+ p01 = prem - p02 * ne01;
443
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
444
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
445
+ ir_prefetch += next_block_size;
227
446
  }
447
+ ir += current_block_size;
448
+ }
449
+ dma_queue_flush(q);
450
+ }
451
+
452
+ // 4. Vector Complex (ne10 == ne00, complex broadcast)
453
+ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
454
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
455
+ struct htp_ops_context * octx = bctx->octx;
456
+ htp_binary_preamble;
457
+
458
+ const uint32_t src0_type = octx->src0.type;
459
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
460
+ const uint32_t total_rows = ne01 * ne02 * ne03;
461
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
462
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
463
+ if (start_row >= end_row) return;
464
+
465
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
466
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
467
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
468
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
469
+
470
+ dma_queue * q = octx->ctx->dma[ith];
471
+ uint32_t ir_prefetch = start_row;
472
+ int spad_idx = 0;
473
+
474
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
475
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
476
+ uint32_t i03, i02, i01, rem;
477
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
478
+ rem = ir_prefetch - i03 * (ne02 * ne01);
479
+ i02 = fastdiv(rem, &bctx->dim1_div);
480
+ i01 = rem - i02 * ne01;
481
+
482
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
483
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
484
+
485
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
486
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
487
+
488
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
489
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
490
+ ir_prefetch += current_block_size;
491
+ spad_idx ^= 1;
228
492
  }
229
493
 
230
- t2 = HAP_perf_get_qtimer_count();
494
+ for (uint32_t ir = start_row; ir < end_row; ) {
495
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
496
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
497
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
498
+
499
+ uint32_t i03, i02, i01, rem;
500
+ i03 = fastdiv(ir, &bctx->dim12_div);
501
+ rem = ir - i03 * (ne02 * ne01);
502
+ i02 = fastdiv(rem, &bctx->dim1_div);
503
+ i01 = rem - i02 * ne01;
504
+
505
+ for (uint32_t r = 0; r < current_block_size; r++) {
506
+ uint32_t r_i01 = i01 + r;
507
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
508
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
509
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
510
+
511
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
512
+ uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
513
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
514
+
515
+ // Read src1 from DDR (unaligned)
516
+ COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
517
+ }
231
518
 
232
- FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
233
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
234
- src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
235
- dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
519
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
520
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
521
+
522
+ if (ir_prefetch < end_row) {
523
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
524
+ uint32_t p03, p02, p01, prem;
525
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
526
+ prem = ir_prefetch - p03 * (ne02 * ne01);
527
+ p02 = fastdiv(prem, &bctx->dim1_div);
528
+ p01 = prem - p02 * ne01;
529
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
530
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
531
+ ir_prefetch += next_block_size;
532
+ }
533
+ ir += current_block_size;
534
+ }
535
+ dma_queue_flush(q);
236
536
  }
237
537
 
238
- static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
239
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
538
+ // 5. Element Repeat (ne10 != ne00)
539
+ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
540
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
541
+ struct htp_ops_context * octx = bctx->octx;
542
+ htp_binary_preamble;
240
543
 
241
- switch (octx->op) {
242
- case HTP_OP_MUL:
243
- case HTP_OP_ADD:
244
- case HTP_OP_SUB:
245
- binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
246
- break;
544
+ const uint32_t src0_type = octx->src0.type;
545
+ const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
546
+ const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
547
+ const uint32_t total_rows = ne01 * ne02 * ne03;
548
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
549
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
550
+ if (start_row >= end_row) return;
551
+
552
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
553
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
554
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
555
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
556
+
557
+ dma_queue * q = octx->ctx->dma[ith];
558
+ uint32_t ir_prefetch = start_row;
559
+ int spad_idx = 0;
560
+
561
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
562
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
563
+ uint32_t i03, i02, i01, rem;
564
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
565
+ rem = ir_prefetch - i03 * (ne02 * ne01);
566
+ i02 = fastdiv(rem, &bctx->dim1_div);
567
+ i01 = rem - i02 * ne01;
568
+
569
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
570
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
571
+
572
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
573
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
574
+
575
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
576
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
577
+ ir_prefetch += current_block_size;
578
+ spad_idx ^= 1;
579
+ }
247
580
 
248
- case HTP_OP_ADD_ID:
249
- binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
250
- break;
581
+ for (uint32_t ir = start_row; ir < end_row; ) {
582
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
583
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
584
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
585
+
586
+ uint32_t i03, i02, i01, rem;
587
+ i03 = fastdiv(ir, &bctx->dim12_div);
588
+ rem = ir - i03 * (ne02 * ne01);
589
+ i02 = fastdiv(rem, &bctx->dim1_div);
590
+ i01 = rem - i02 * ne01;
591
+
592
+ for (uint32_t r = 0; r < current_block_size; r++) {
593
+ uint32_t r_i01 = i01 + r;
594
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
595
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
596
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
597
+
598
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
599
+ uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
600
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
601
+
602
+ // Repeat src1 row
603
+ for (uint32_t c = 0; c < ne00; c += ne10) {
604
+ uint32_t len = MIN(ne10, ne00 - c);
605
+ // Use UUU for speed and simplicity
606
+ COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
607
+ }
608
+ }
251
609
 
252
- default:
253
- FARF(ERROR, "Unknown Binary Op %u", octx->op);
254
- break;
610
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
611
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
612
+
613
+ if (ir_prefetch < end_row) {
614
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
615
+ uint32_t p03, p02, p01, prem;
616
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
617
+ prem = ir_prefetch - p03 * (ne02 * ne01);
618
+ p02 = fastdiv(prem, &bctx->dim1_div);
619
+ p01 = prem - p02 * ne01;
620
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
621
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
622
+ ir_prefetch += next_block_size;
623
+ }
624
+ ir += current_block_size;
255
625
  }
626
+ dma_queue_flush(q);
256
627
  }
257
628
 
258
- static int execute_op_binary_f32(struct htp_ops_context * octx) {
259
- int err = HTP_STATUS_OK;
629
+ // 6. ADD_ID (src1 gathered via src2 indices)
630
+ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
631
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
632
+ struct htp_ops_context * octx = bctx->octx;
260
633
 
261
634
  const struct htp_tensor * src0 = &octx->src0;
262
635
  const struct htp_tensor * src1 = &octx->src1;
636
+ const struct htp_tensor * src2 = &octx->src2;
263
637
  struct htp_tensor * dst = &octx->dst;
264
638
 
265
- worker_callback_t binary_op_func;
266
- const char * op_type = NULL;
267
-
268
- switch (octx->op) {
269
- case HTP_OP_MUL:
270
- binary_op_func = binary_job_dispatcher_f32;
271
- op_type = "mul-f32";
272
- break;
273
-
274
- case HTP_OP_ADD:
275
- binary_op_func = binary_job_dispatcher_f32;
276
- op_type = "add-f32";
277
- break;
278
-
279
- case HTP_OP_SUB:
280
- binary_op_func = binary_job_dispatcher_f32;
281
- op_type = "sub-f32";
282
- break;
283
-
284
- case HTP_OP_ADD_ID:
285
- binary_op_func = binary_job_dispatcher_f32;
286
- op_type = "add-id-f32";
287
- break;
288
-
289
- default:
290
- FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
291
- return HTP_STATUS_NO_SUPPORT;
639
+ const uint32_t ne00 = src0->ne[0];
640
+ const uint32_t ne01 = src0->ne[1];
641
+ const uint32_t ne02 = src0->ne[2];
642
+ const uint32_t ne03 = src0->ne[3];
643
+ const uint32_t ne11 = src1->ne[1]; // for bounds check
644
+
645
+ const uint32_t nb01 = src0->nb[1];
646
+ const uint32_t nb02 = src0->nb[2];
647
+ const uint32_t nb03 = src0->nb[3];
648
+ const uint32_t nb11 = src1->nb[1]; // src1 row stride
649
+ const uint32_t nb1 = dst->nb[1];
650
+ const uint32_t nb2 = dst->nb[2];
651
+ const uint32_t nb3 = dst->nb[3];
652
+
653
+ const uint32_t total_rows = ne01 * ne02 * ne03;
654
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
655
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
656
+ if (start_row >= end_row) return;
657
+
658
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
659
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
660
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
661
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
662
+
663
+ dma_queue * q = octx->ctx->dma[ith];
664
+ uint32_t ir_prefetch = start_row;
665
+ int spad_idx = 0;
666
+
667
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
668
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
669
+ uint32_t i03, i02, i01, rem;
670
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
671
+ rem = ir_prefetch - i03 * (ne02 * ne01);
672
+ i02 = fastdiv(rem, &bctx->dim1_div);
673
+ i01 = rem - i02 * ne01;
674
+
675
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
676
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
677
+
678
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
679
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
680
+
681
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
682
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
683
+ ir_prefetch += current_block_size;
684
+ spad_idx ^= 1;
292
685
  }
293
686
 
294
- const int n_threads = octx->n_threads;
687
+ for (uint32_t ir = start_row; ir < end_row; ) {
688
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
689
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
690
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
691
+
692
+ uint32_t i03, i02, i01, rem;
693
+ i03 = fastdiv(ir, &bctx->dim12_div);
694
+ rem = ir - i03 * (ne02 * ne01);
695
+ i02 = fastdiv(rem, &bctx->dim1_div);
696
+ i01 = rem - i02 * ne01;
697
+
698
+ for (uint32_t r = 0; r < current_block_size; r++) {
699
+ uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
700
+
701
+ const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
702
+
703
+ uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
704
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
705
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
706
+
707
+ hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
708
+ }
709
+
710
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
711
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
712
+
713
+ if (ir_prefetch < end_row) {
714
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
715
+ uint32_t p03, p02, p01, prem;
716
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
717
+ prem = ir_prefetch - p03 * (ne02 * ne01);
718
+ p02 = fastdiv(prem, &bctx->dim1_div);
719
+ p01 = prem - p02 * ne01;
720
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
721
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
722
+ ir_prefetch += next_block_size;
723
+ }
724
+ ir += current_block_size;
725
+ }
726
+ dma_queue_flush(q);
727
+ }
728
+
729
+ static int execute_op_binary(struct htp_ops_context * octx) {
730
+ const struct htp_tensor * src0 = &octx->src0;
731
+ const struct htp_tensor * src1 = &octx->src1;
732
+ struct htp_tensor * dst = &octx->dst;
733
+
295
734
  const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
735
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
736
+
737
+ // Use packed row sizes for VTCM allocation
738
+ const uint32_t src0_type = octx->src0.type;
739
+ const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
740
+ const size_t src0_row_size = src0->ne[0] * elem_size;
741
+ const size_t src1_row_size = src1->ne[0] * elem_size;
742
+ const size_t dst_row_size = dst->ne[0] * elem_size;
743
+
744
+ // Align to VLEN
745
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
746
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
747
+ size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
748
+
749
+ bool is_add_id = (octx->op == HTP_OP_ADD_ID);
750
+ bool is_scalar = !is_add_id && (src1->ne[0] == 1);
751
+
752
+ // Determine which kernel we will use to alloc memory and dispatch
753
+ bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
754
+ (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
755
+ (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
756
+ (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
757
+
758
+ bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
759
+ bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
760
+ bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
761
+
762
+ size_t spad_row_total;
763
+ if (is_scalar) {
764
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
765
+ } else if (is_row_bcast) {
766
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
767
+ } else if (use_vector_same) {
768
+ spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
769
+ } else if (is_add_id) {
770
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
771
+ } else {
772
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
773
+ }
296
774
 
297
- const size_t src0_row_size = src0->nb[1];
298
- const size_t src1_row_size = src1->nb[1];
299
- const size_t dst_row_size = dst->nb[1];
775
+ size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
776
+ // Adjust for static src1 in row_bcast case
777
+ if (is_row_bcast) {
778
+ size_t needed_static = src1_row_size_aligned;
779
+ if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
780
+ size_t avail = octx->ctx->vtcm_size - needed_static;
781
+ rows_per_buffer = avail / (n_threads * spad_row_total);
782
+ }
300
783
 
301
- // VTCM scratchpads for all tensors
302
- octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
303
- octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
304
- octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
784
+ if (rows_per_buffer < 1) {
785
+ FARF(ERROR, "binary: VTCM too small\n");
786
+ return HTP_STATUS_VTCM_TOO_SMALL;
787
+ }
305
788
 
306
- size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
789
+ octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
790
+ octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
307
791
 
308
- FARF(HIGH,
309
- "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
310
- op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
311
- src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
312
- octx->dst_spad.size);
792
+ if (is_scalar || use_complex || use_repeat || is_add_id) {
793
+ octx->src1_spad.size_per_thread = 0;
794
+ } else if (is_row_bcast) {
795
+ octx->src1_spad.size_per_thread = 0;
796
+ } else {
797
+ octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
798
+ }
313
799
 
314
- // Make sure the reserved vtcm size is sufficient
315
- if (octx->ctx->vtcm_size < spad_size) {
316
- FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
317
- octx->ctx->vtcm_size, spad_size);
800
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
801
+ if (is_row_bcast) {
802
+ octx->src1_spad.size = src1_row_size_aligned;
803
+ } else {
804
+ octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
805
+ }
806
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
807
+
808
+ if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
318
809
  return HTP_STATUS_VTCM_TOO_SMALL;
319
810
  }
320
811
 
@@ -322,39 +813,79 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
322
813
  octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
323
814
  octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
324
815
 
325
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
326
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
816
+ if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
817
+ return HTP_STATUS_OK;
818
+ }
819
+
820
+ dma_queue * q = octx->ctx->dma[0];
821
+ if (is_row_bcast) {
822
+ dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
823
+ }
824
+
825
+ struct htp_binary_context bctx;
826
+ bctx.octx = octx;
827
+ bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
828
+ bctx.block_max = rows_per_buffer;
829
+ bctx.src0_row_size_aligned = src0_row_size_aligned;
830
+ bctx.src1_row_size_aligned = src1_row_size_aligned;
831
+ bctx.dst_row_size_aligned = dst_row_size_aligned;
327
832
 
328
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
833
+ bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
834
+ bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
835
+ bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
329
836
 
330
- octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
331
- octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
332
- octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
333
- octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
837
+ bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
838
+ bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
839
+ bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
334
840
 
335
- octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
336
- octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
337
- octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
338
- octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
841
+ bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
842
+ bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
339
843
 
340
- worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
844
+ bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
845
+ bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
846
+
847
+ bctx.split_at_ne01 = (src0->ne[2] > 1) &&
848
+ ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
849
+
850
+ bctx.split_at_ne02 = (src0->ne[3] > 1) &&
851
+ ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
852
+
853
+ // Precompute specific kernel parameters
854
+ if (use_vector_same) {
855
+ bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
856
+ bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
857
+ }
858
+
859
+ worker_callback_t worker_func;
860
+ if (is_add_id) worker_func = binary_job_add_id;
861
+ else if (is_scalar) worker_func = binary_job_scalar;
862
+ else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
863
+ else if (use_vector_same) worker_func = binary_job_vector_same_shape;
864
+ else if (use_complex) worker_func = binary_job_vector_complex;
865
+ else worker_func = binary_job_element_repeat;
866
+
867
+ if (is_row_bcast) {
868
+ dma_queue_pop(q);
341
869
  }
342
870
 
343
- return err;
871
+ worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
872
+
873
+ return HTP_STATUS_OK;
344
874
  }
345
875
 
346
876
  int op_binary(struct htp_ops_context * octx) {
347
- int err = HTP_STATUS_OK;
348
877
 
349
- switch (octx->src0.type) {
350
- case HTP_TYPE_F32:
351
- err = execute_op_binary_f32(octx);
352
- break;
878
+ // Does not support permutations of src1
879
+ const struct htp_tensor * src1 = &octx->src1;
880
+ if (src1->nb[1] < src1->nb[0]) {
881
+ return HTP_STATUS_NO_SUPPORT;
882
+ }
353
883
 
354
- default:
355
- err = HTP_STATUS_NO_SUPPORT;
356
- break;
884
+ const uint32_t src0_type = octx->src0.type;
885
+ if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
886
+ return execute_op_binary(octx);
357
887
  }
358
888
 
359
- return err;
889
+ return HTP_STATUS_NO_SUPPORT;
360
890
  }
891
+