whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -2,27 +2,20 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
8
5
  #include <HAP_farf.h>
9
- #include <HAP_mem.h>
10
6
  #include <HAP_perf.h>
11
- #include <HAP_ps.h>
12
- #include <hexagon_protos.h>
13
- #include <hexagon_types.h>
7
+
14
8
  #include <math.h>
15
- #include <qurt_thread.h>
16
9
  #include <string.h>
17
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-utils.h"
13
+
18
14
  #define GGML_COMMON_DECL_C
19
15
  #include "ggml-common.h"
20
16
  #include "htp-ctx.h"
21
- #include "htp-dma.h"
22
17
  #include "htp-msg.h"
23
18
  #include "htp-ops.h"
24
- #include "hvx-utils.h"
25
- #include "ops-utils.h"
26
19
 
27
20
  #define htp_act_preamble3 \
28
21
  const uint32_t ne00 = src0->ne[0]; \
@@ -76,27 +69,45 @@
76
69
  const uint32_t nb2 = dst->nb[2]; \
77
70
  const uint32_t nb3 = dst->nb[3];
78
71
 
79
- static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
80
- const struct htp_tensor * src1,
81
- struct htp_tensor * dst,
82
- const int32_t * op_params,
83
- struct htp_spad * src0_spad,
84
- struct htp_spad * src1_spad,
85
- struct htp_spad * dst_spad,
86
- uint32_t nth,
87
- uint32_t ith,
88
- uint32_t src0_nrows_per_thread,
89
- dma_queue * dma_queue) {
72
+ struct htp_act_context {
73
+ struct htp_ops_context * octx;
74
+
75
+ // Precomputed values
76
+ const uint8_t * data_src0;
77
+ const uint8_t * data_src1;
78
+ uint8_t * data_dst;
79
+
80
+ size_t src0_row_size;
81
+ size_t src1_row_size;
82
+ size_t dst_row_size;
83
+
84
+ size_t src0_row_size_aligned;
85
+ size_t src1_row_size_aligned;
86
+ size_t dst_row_size_aligned;
87
+
88
+ size_t src0_spad_half_size;
89
+ size_t src1_spad_half_size;
90
+ size_t dst_spad_half_size;
91
+
92
+ uint32_t block;
93
+ uint32_t src0_nrows;
94
+ uint32_t src0_nrows_per_thread;
95
+ int nc;
96
+ };
97
+
98
+ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
99
+ struct htp_act_context * actx = (struct htp_act_context *) data;
100
+ const struct htp_tensor * src0 = &actx->octx->src0;
101
+ const struct htp_tensor * src1 = &actx->octx->src1;
102
+ const struct htp_tensor * dst = &actx->octx->dst;
90
103
  htp_act_preamble3;
91
104
 
92
- size_t src0_row_size = nb01;
93
- size_t src1_row_size = nb11;
94
- size_t dst_row_size = nb1;
95
-
96
-
97
-
98
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
105
+ size_t src0_row_size = actx->src0_row_size;
106
+ size_t src1_row_size = actx->src1_row_size;
107
+ size_t dst_row_size = actx->dst_row_size;
99
108
 
109
+ const uint32_t src0_nrows = actx->src0_nrows;
110
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
100
111
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
101
112
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
102
113
 
@@ -108,43 +119,34 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
108
119
  uint64_t t1, t2;
109
120
  t1 = HAP_perf_get_qtimer_count();
110
121
 
111
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
112
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
113
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
122
+ const uint8_t * restrict data_src0 = actx->data_src0;
123
+ const uint8_t * restrict data_src1 = actx->data_src1;
124
+ uint8_t * restrict data_dst = actx->data_dst;
114
125
 
115
- const bool src1_valid = src1->ne[0];
116
- const int nc = (src1_valid) ? ne00 : ne00 / 2;
117
- if (!src1_valid) {
118
- const int32_t swapped = op_params[1];
119
- data_src1 = data_src0;
120
- src1_row_size = src0_row_size;
121
-
122
- const size_t nc_in_bytes = nc * SIZEOF_FP32;
123
- data_src0 += swapped ? nc_in_bytes : 0;
124
- data_src1 += swapped ? 0 : nc_in_bytes;
125
- }
126
+ const int nc = actx->nc;
126
127
 
127
- const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
128
- const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
129
- const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
128
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
129
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
130
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
130
131
 
131
- uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
132
- uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
133
- uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
132
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
133
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
134
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
134
135
 
135
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
136
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
137
- size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
138
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
136
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
137
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
138
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
139
139
 
140
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
140
+ const int BLOCK = actx->block;
141
141
  if (BLOCK == 0) {
142
142
  FARF(ERROR,
143
143
  "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
144
- src0_spad->size_per_thread, src0_row_size_aligned);
144
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
145
145
  return;
146
146
  }
147
147
 
148
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
149
+
148
150
  // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
149
151
  for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
150
152
  const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -175,9 +177,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
175
177
  float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
176
178
 
177
179
  //swiglu(x) = x1 * sigmoid(x0)
178
- hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
179
- hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
180
- (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
180
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);
181
+ hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
182
+ (const uint8_t *) src1_spad_ptr, nc);
181
183
  }
182
184
 
183
185
  dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
@@ -203,27 +205,22 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
203
205
  (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
204
206
  }
205
207
 
206
- static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
207
- const struct htp_tensor * src1,
208
- struct htp_tensor * dst,
209
- const int32_t * op_params,
210
- struct htp_spad * src0_spad,
211
- struct htp_spad * src1_spad,
212
- struct htp_spad * dst_spad,
213
- uint32_t nth,
214
- uint32_t ith,
215
- uint32_t src0_nrows_per_thread,
216
- dma_queue * dma_queue) {
208
+ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
209
+ struct htp_act_context * actx = (struct htp_act_context *) data;
210
+ const struct htp_tensor * src0 = &actx->octx->src0;
211
+ const struct htp_tensor * src1 = &actx->octx->src1;
212
+ const struct htp_tensor * dst = &actx->octx->dst;
217
213
  htp_act_preamble3;
218
214
 
219
215
  uint64_t t1, t2;
220
216
  t1 = HAP_perf_get_qtimer_count();
221
217
 
222
- size_t src0_row_size = nb01;
223
- size_t src1_row_size = nb11;
224
- size_t dst_row_size = nb1;
218
+ size_t src0_row_size = actx->src0_row_size;
219
+ size_t src1_row_size = actx->src1_row_size;
220
+ size_t dst_row_size = actx->dst_row_size;
225
221
 
226
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
222
+ const uint32_t src0_nrows = actx->src0_nrows;
223
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
227
224
 
228
225
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
229
226
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -233,45 +230,36 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
233
230
  return;
234
231
  }
235
232
 
236
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
237
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
238
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
233
+ const uint8_t * restrict data_src0 = actx->data_src0;
234
+ const uint8_t * restrict data_src1 = actx->data_src1;
235
+ uint8_t * restrict data_dst = actx->data_dst;
239
236
 
240
- const bool src1_valid = src1->ne[0];
241
- const int nc = (src1_valid) ? ne00 : ne00 / 2;
242
- if (!src1_valid) {
243
- const int32_t swapped = op_params[1];
244
- data_src1 = data_src0;
245
- src1_row_size = src0_row_size;
237
+ const int nc = actx->nc;
246
238
 
247
- const size_t nc_in_bytes = nc * SIZEOF_FP32;
248
- data_src0 += swapped ? nc_in_bytes : 0;
249
- data_src1 += swapped ? 0 : nc_in_bytes;
250
- }
239
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
240
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
241
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
251
242
 
252
- const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
253
- const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
254
- const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
243
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
244
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
245
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
255
246
 
256
- uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
257
- uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
258
- uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
247
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
248
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
249
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
259
250
 
260
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
261
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
262
- size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
263
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
264
-
265
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
251
+ const int BLOCK = actx->block;
266
252
  if (BLOCK == 0) {
267
253
  FARF(ERROR,
268
254
  "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
269
255
  "%zu\n",
270
- src0_spad->size_per_thread, src0_row_size_aligned);
256
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
271
257
  return;
272
258
  }
273
- const float alpha = ((const float *) (op_params))[2];
274
- const float limit = ((const float *) (op_params))[3];
259
+ const float alpha = ((const float *) (actx->octx->op_params))[2];
260
+ const float limit = ((const float *) (actx->octx->op_params))[3];
261
+
262
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
275
263
 
276
264
  // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
277
265
  for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
@@ -304,18 +292,18 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
304
292
  float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
305
293
 
306
294
  // x (src0_spad_data) = std::min(src0_p[k], limit);
307
- hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
295
+ hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);
308
296
  // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
309
- hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
297
+ hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);
310
298
  // y (src1_spad_data) = y1 + 1.f
311
- hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
299
+ hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);
312
300
  // x1 (dst_spad_data) = alpha * (x)
313
- hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
301
+ hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);
314
302
  // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
315
- hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
303
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);
316
304
  // out = x * sigmoid(alpha * x) * (y + 1.f)
317
- hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
318
- (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
305
+ hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
306
+ (const uint8_t *) src1_spad_ptr, nc);
319
307
  }
320
308
 
321
309
  dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
@@ -342,26 +330,22 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
342
330
  }
343
331
 
344
332
 
345
- static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
346
- struct htp_tensor * dst,
347
- const int32_t * op_params,
348
- struct htp_spad * src0_spad,
349
- struct htp_spad * dst_spad,
350
- uint32_t nth,
351
- uint32_t ith,
352
- uint32_t src0_nrows_per_thread,
353
- dma_queue * dma_queue) {
333
+ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
334
+ struct htp_act_context * actx = (struct htp_act_context *) data;
335
+ const struct htp_tensor * src0 = &actx->octx->src0;
336
+ const struct htp_tensor * dst = &actx->octx->dst;
354
337
  htp_act_preamble2;
355
338
 
356
339
  uint64_t t1, t2;
357
340
  t1 = HAP_perf_get_qtimer_count();
358
341
 
359
- const size_t src0_row_size = nb01;
360
- const size_t dst_row_size = nb1;
361
- const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
362
- const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
342
+ const size_t src0_row_size = actx->src0_row_size;
343
+ const size_t dst_row_size = actx->dst_row_size;
344
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
345
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
363
346
 
364
- const uint32_t src0_nrows = ne01 * ne02 * ne03;
347
+ const uint32_t src0_nrows = actx->src0_nrows;
348
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
365
349
 
366
350
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
367
351
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -371,25 +355,29 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
371
355
  return;
372
356
  }
373
357
 
374
- const uint8_t * data_src0 = (const uint8_t *) src0->data;
375
- uint8_t * data_dst = (uint8_t *) dst->data;
358
+ const uint8_t * data_src0 = actx->data_src0;
359
+ uint8_t * data_dst = actx->data_dst;
360
+
361
+ // nc/ne0 matches.
362
+ const int ne0_val = actx->nc; // == dst->ne[0]
376
363
 
377
- uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
378
- uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
364
+ uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
365
+ uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
379
366
 
380
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
381
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
382
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
367
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
368
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
383
369
 
384
370
  // In gelu = x*sigmoid(x*1.702)
385
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
371
+ const int BLOCK = actx->block;
386
372
 
387
373
  if (BLOCK == 0) {
388
374
  FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
389
- src0_spad->size_per_thread, src0_row_size_aligned);
375
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
390
376
  return;
391
377
  }
392
378
 
379
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
380
+
393
381
  // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
394
382
  for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
395
383
  const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -415,9 +403,9 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
415
403
  float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
416
404
 
417
405
  // gelu = x * sigmoid(1.702 * x) // current implementation
418
- hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0);
419
- hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
420
- hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
406
+ hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
407
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
408
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
421
409
  }
422
410
 
423
411
  dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -442,34 +430,23 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
442
430
  ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
443
431
  }
444
432
 
445
- static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
446
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
447
- unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
448
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
449
- }
450
-
451
433
 
452
-
453
- static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
454
- struct htp_tensor * dst,
455
- const int32_t * op_params,
456
- struct htp_spad * src0_spad,
457
- struct htp_spad * dst_spad,
458
- uint32_t nth,
459
- uint32_t ith,
460
- uint32_t src0_nrows_per_thread,
461
- dma_queue * dma_queue) {
434
+ static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
435
+ struct htp_act_context * actx = (struct htp_act_context *) data;
436
+ const struct htp_tensor * src0 = &actx->octx->src0;
437
+ const struct htp_tensor * dst = &actx->octx->dst;
462
438
  htp_act_preamble2;
463
439
 
464
440
  uint64_t t1, t2;
465
441
  t1 = HAP_perf_get_qtimer_count();
466
442
 
467
- const size_t src0_row_size = nb01;
468
- const size_t dst_row_size = nb1;
469
- const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
470
- const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
443
+ const size_t src0_row_size = actx->src0_row_size;
444
+ const size_t dst_row_size = actx->dst_row_size;
445
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
446
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
471
447
 
472
- const uint32_t src0_nrows = ne01 * ne02 * ne03;
448
+ const uint32_t src0_nrows = actx->src0_nrows;
449
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
473
450
 
474
451
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
475
452
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -479,24 +456,27 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
479
456
  return;
480
457
  }
481
458
 
482
- const uint8_t * data_src0 = (const uint8_t *) src0->data;
483
- uint8_t * data_dst = (uint8_t *) dst->data;
459
+ const uint8_t * data_src0 = actx->data_src0;
460
+ uint8_t * data_dst = actx->data_dst;
484
461
 
485
- uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
486
- uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
462
+ const int ne0_val = actx->nc; // == dst->ne[0]
487
463
 
488
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
489
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
490
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
464
+ uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
465
+ uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
491
466
 
492
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
467
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
468
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
469
+
470
+ const int BLOCK = actx->block;
493
471
 
494
472
  if (BLOCK == 0) {
495
473
  FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
496
- src0_spad->size_per_thread, src0_row_size_aligned);
474
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
497
475
  return;
498
476
  }
499
477
 
478
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
479
+
500
480
  // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
501
481
  for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
502
482
  const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -522,8 +502,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
522
502
  float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
523
503
 
524
504
  // silu = x * sigmoid(x)
525
- hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
526
- hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
505
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
506
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
527
507
  }
528
508
 
529
509
  dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -548,27 +528,130 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
548
528
  ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
549
529
  }
550
530
 
551
- static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
552
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
553
- unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
554
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
555
- }
531
+ static const float GELU_COEF_A = 0.044715f;
532
+ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
556
533
 
557
- static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
558
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
559
- glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
560
- &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
561
- }
534
+ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
535
+ struct htp_act_context * actx = (struct htp_act_context *) data;
536
+ const struct htp_tensor * src0 = &actx->octx->src0;
537
+ const struct htp_tensor * src1 = &actx->octx->src1;
538
+ const struct htp_tensor * dst = &actx->octx->dst;
539
+ htp_act_preamble3;
562
540
 
563
- static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
564
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
565
- glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
566
- &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
567
- }
541
+ size_t src0_row_size = actx->src0_row_size;
542
+ size_t src1_row_size = actx->src1_row_size;
543
+ size_t dst_row_size = actx->dst_row_size;
568
544
 
569
- static int execute_op_activations_fp32(struct htp_ops_context * octx) {
570
- int err = HTP_STATUS_OK;
545
+ uint64_t t1, t2;
546
+ t1 = HAP_perf_get_qtimer_count();
547
+
548
+ const uint32_t src0_nrows = actx->src0_nrows;
549
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
550
+
551
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
552
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
553
+
554
+ // no work for this thread
555
+ if (src0_start_row >= src0_end_row) {
556
+ return;
557
+ }
558
+
559
+ const uint8_t * restrict data_src0 = actx->data_src0;
560
+ const uint8_t * restrict data_src1 = actx->data_src1;
561
+ uint8_t * restrict data_dst = actx->data_dst;
562
+
563
+ const int nc = actx->nc;
564
+
565
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
566
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
567
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
568
+
569
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
570
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
571
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
572
+
573
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
574
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
575
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
576
+
577
+ const int BLOCK = actx->block;
578
+ if (BLOCK == 0) {
579
+ FARF(ERROR,
580
+ "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
581
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
582
+ return;
583
+ }
584
+
585
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
586
+
587
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
588
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
589
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
590
+
591
+ // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
592
+ dma_queue_push_vtcm_to_ddr(dma_queue,
593
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
594
+ dst_row_size, dst_row_size_aligned, 0);
595
+
596
+ dma_queue_push_ddr_to_vtcm(dma_queue,
597
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
598
+ src0_row_size_aligned, src0_row_size, block_size);
599
+ dma_queue_push_ddr_to_vtcm(dma_queue,
600
+ dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
601
+ src1_row_size_aligned, src1_row_size, block_size);
602
+ }
603
+
604
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
605
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
606
+
607
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
608
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
609
+ float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
610
+
611
+ for (uint32_t ib = 0; ib < block_size; ib++) {
612
+ const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
613
+ const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
614
+ uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
615
+
616
+ // geglu tanh implementation
617
+ // geglu(x, g) = gelu(x) * g
618
+ // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
619
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
620
+ hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
621
+ hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
622
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
623
+ hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
624
+ hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
625
+ hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
626
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
627
+ hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
628
+ hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
629
+ }
630
+
631
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
632
+ dst_row_size_aligned, block_size);
633
+
634
+ // prefetch N+2 loop iteration if any
635
+ const uint32_t pref_block = (ir + BLOCK * 2);
636
+ if (pref_block < src0_end_row) {
637
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
638
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
639
+ src0_row_size_aligned, src0_row_size, pref_block_size);
640
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
641
+ src1_row_size_aligned, src1_row_size, pref_block_size);
642
+ }
643
+ }
644
+
645
+ dma_queue_flush(dma_queue);
646
+
647
+ t2 = HAP_perf_get_qtimer_count();
648
+
649
+ FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
650
+ ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
651
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
652
+ }
571
653
 
654
+ static int execute_op_activations_f32(struct htp_ops_context * octx) {
572
655
  const struct htp_tensor * src0 = &octx->src0;
573
656
  const struct htp_tensor * src1 = &octx->src1;
574
657
  struct htp_tensor * dst = &octx->dst;
@@ -583,30 +666,35 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
583
666
 
584
667
  switch (octx->op) {
585
668
  case HTP_OP_UNARY_SILU:
586
- act_op_func = unary_silu_fp32;
669
+ act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
587
670
  op_type = "silu-f32";
588
671
  break;
589
672
 
590
673
  case HTP_OP_GLU_SWIGLU:
591
- act_op_func = glu_swiglu_fp32;
674
+ act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
592
675
  op_type = "swiglu-f32";
593
676
  break;
594
677
 
595
678
  case HTP_OP_GLU_SWIGLU_OAI:
596
- act_op_func = glu_swiglu_oai_fp32;
679
+ act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
597
680
  op_type = "swiglu-oai-f32";
598
681
  break;
599
682
  case HTP_OP_UNARY_GELU:
600
- act_op_func = unary_gelu_fp32;
683
+ act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
601
684
  op_type = "gelu-f32";
602
685
  break;
686
+
687
+ case HTP_OP_GLU_GEGLU:
688
+ act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
689
+ op_type = "geglu-f32";
690
+ break;
603
691
  default:
604
692
  FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
605
693
  return HTP_STATUS_NO_SUPPORT;
606
694
  }
607
695
 
608
- const uint32_t n_threads = octx->n_threads;
609
696
  const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
697
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
610
698
 
611
699
  size_t src0_row_size = src0->nb[1];
612
700
  size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
@@ -617,9 +705,9 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
617
705
  src1_row_size = src0_row_size;
618
706
  }
619
707
 
620
- const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
621
- const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
622
- const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
708
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
709
+ const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
710
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
623
711
  // VTCM scratchpads for all tensors
624
712
  // N rows per thread, padded to HVX vector size
625
713
 
@@ -656,13 +744,56 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
656
744
  octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
657
745
  }
658
746
 
659
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
660
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
661
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
662
- worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
747
+ if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
748
+ return HTP_STATUS_OK;
663
749
  }
664
750
 
665
- return err;
751
+ // Prepare context
752
+ struct htp_act_context actx;
753
+ actx.octx = octx;
754
+
755
+ actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
756
+
757
+ actx.src0_row_size = src0_row_size;
758
+ actx.src1_row_size = src1_row_size;
759
+ actx.dst_row_size = dst_row_size;
760
+
761
+ actx.src0_row_size_aligned = src0_row_size_aligned;
762
+ actx.src1_row_size_aligned = src1_row_size_aligned;
763
+ actx.dst_row_size_aligned = dst_row_size_aligned;
764
+
765
+ actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
766
+ actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
767
+ actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
768
+
769
+ actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
770
+ actx.src0_nrows = src0_nrows;
771
+
772
+ actx.nc = dst->ne[0];
773
+
774
+ // Pointers and GLU logic
775
+ const uint8_t * data_src0 = (const uint8_t *) src0->data;
776
+ const uint8_t * data_src1 = (const uint8_t *) src1->data;
777
+
778
+ if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
779
+ const int32_t swapped = octx->op_params[1];
780
+ data_src1 = data_src0;
781
+ actx.src1_row_size = actx.src0_row_size;
782
+
783
+ size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
784
+ if (swapped) {
785
+ data_src0 += nc_in_bytes;
786
+ } else {
787
+ data_src1 += nc_in_bytes;
788
+ }
789
+ }
790
+
791
+ actx.data_src0 = data_src0;
792
+ actx.data_src1 = data_src1;
793
+ actx.data_dst = (uint8_t *) dst->data;
794
+
795
+ worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
796
+ return HTP_STATUS_OK;
666
797
  }
667
798
 
668
799
  int op_activations(struct htp_ops_context * octx) {
@@ -670,7 +801,7 @@ int op_activations(struct htp_ops_context * octx) {
670
801
 
671
802
  switch (octx->src0.type) {
672
803
  case HTP_TYPE_F32:
673
- err = execute_op_activations_fp32(octx);
804
+ err = execute_op_activations_f32(octx);
674
805
  break;
675
806
 
676
807
  default: