whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017) hide show
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -2,166 +2,299 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
5
+ #include <assert.h>
8
6
  #include <HAP_farf.h>
9
- #include <HAP_mem.h>
10
7
  #include <HAP_perf.h>
11
- #include <hexagon_protos.h>
12
- #include <hexagon_types.h>
13
8
  #include <math.h>
14
9
  #include <string.h>
15
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-utils.h"
13
+ #include "hvx-dump.h"
14
+ #include "hvx-flash-attn.h"
15
+
16
16
  #define GGML_COMMON_DECL_C
17
17
  #include "ggml-common.h"
18
18
  #include "htp-ctx.h"
19
- #include "htp-dma.h"
20
- #include "htp-msg.h"
21
19
  #include "htp-ops.h"
22
- #include "hvx-utils.h"
23
- #include "ops-utils.h"
20
+ #include "htp-ops.h"
21
+ #include "hmx-ops.h"
22
+
23
+ // Must be multiple of 32
24
+ #define FLASH_ATTN_BLOCK_SIZE (32 * 2)
25
+
26
+ #if __HVX_ARCH__ < 79
27
+ #define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
28
+ #define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
29
+ #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
30
+ #else
31
+ #define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
32
+ #define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
33
+ #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
34
+ #endif
24
35
 
25
- // Dot product of FP32 and FP16 vectors, accumulating to float
26
- static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
27
- const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
36
+ // This is a bit of a hack because the compiler is strugling to properly inline
37
+ // the default hvx_vec_f32_to_f16 with output into the local array.
38
+ static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
39
+ {
40
+ *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
41
+ }
42
+
43
+ // Dot product of two F16 vectors, accumulating to float
44
+ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
28
45
  const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
46
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
29
47
 
30
48
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
31
49
  uint32_t nloe = n % VLEN_FP16; // leftover elements
32
50
 
33
- const HVX_Vector zero = Q6_V_vsplat_R(0);
34
- HVX_Vector rsum = Q6_V_vsplat_R(0);
51
+ HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
35
52
 
36
53
  uint32_t i = 0;
37
54
 
38
55
  #pragma unroll(4)
39
56
  for (i = 0; i < nvec; i++) {
40
- // Load y (fp32) and convert into fp16
41
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
42
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
43
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
44
-
45
- // Load x (fp16)
46
- HVX_Vector x_hf = vx[i];
57
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
58
+ }
47
59
 
48
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
60
+ if (nloe) {
61
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
62
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
63
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
49
64
 
50
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
65
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
51
66
  }
52
67
 
53
- if (nloe) {
54
- // Load y (fp32) and convert into fp16
55
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
56
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
57
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
68
+ HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
69
+ rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
70
+ hvx_vec_store_u(r, 4, rsum);
71
+ }
58
72
 
59
- // Load x (fp16)
60
- HVX_Vector x_hf = vx[i];
73
+ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
74
+ const uint8_t * restrict x,
75
+ const size_t stride_x,
76
+ const size_t nvec,
77
+ const size_t nloe) {
78
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
79
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
80
+ const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
81
+ const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
82
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
83
+
84
+ HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
85
+ HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
86
+ HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
87
+ HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
61
88
 
62
- // Zero-out unused elements
63
- // Note that we need to clear both x and y because they may contain NANs
64
- HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
65
- x_hf = Q6_V_vand_QV(bmask, x_hf);
66
- y_hf = Q6_V_vand_QV(bmask, y_hf);
89
+ uint32_t i = 0;
67
90
 
68
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
91
+ for (i = 0; i < nvec; i++) {
92
+ HVX_Vector y_hf = vy[i];
93
+ HVX_Vector x0_hf = vx0[i];
94
+ HVX_Vector x1_hf = vx1[i];
95
+ HVX_Vector x2_hf = vx2[i];
96
+ HVX_Vector x3_hf = vx3[i];
97
+
98
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
99
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
100
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
101
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
102
+ }
69
103
 
70
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
104
+ if (nloe) {
105
+ // Load x (fp16) and zero-out unused elements
106
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
107
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
108
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
109
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
110
+ HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
111
+ HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
112
+
113
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
114
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
115
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
116
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
71
117
  }
72
118
 
73
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
74
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
119
+ HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
120
+ HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
121
+ HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
122
+ HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
75
123
 
76
- hvx_vec_store_u(r, 4, rsum);
124
+ HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
125
+ return hvx_vec_reduce_sum_f32x4(rsum0123);
77
126
  }
78
127
 
79
- // Dot product of two F16 vectors, accumulating to float
80
- static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
81
- const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
82
- const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
128
+ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
129
+ const uint8_t * restrict x,
130
+ const size_t stride_x,
131
+ const size_t n,
132
+ float s) {
133
+
134
+ const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
135
+ const size_t nloe = n % VLEN_FP16; // leftover elements
136
+
137
+ HVX_Vector sums = Q6_V_vzero();
138
+ const size_t stride_x_4 = stride_x * 4;
139
+ for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
140
+ HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
141
+ HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
142
+ sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
143
+ x += stride_x_4;
144
+ }
145
+
146
+ return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
147
+ }
148
+
149
+ // MAD: y (F32) += x (F16) * s (F16)
150
+ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {
151
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
152
+
153
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
154
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
83
155
 
84
156
  uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
85
157
  uint32_t nloe = n % VLEN_FP16; // leftover elements
86
158
 
87
- const HVX_Vector zero = Q6_V_vsplat_R(0);
88
- HVX_Vector rsum = Q6_V_vsplat_R(0);
159
+ HVX_Vector S0 = hvx_vec_splat_f16(*s);
89
160
 
90
161
  uint32_t i = 0;
91
162
 
92
- #pragma unroll(4)
93
- for (i = 0; i < nvec; i++) {
94
- HVX_Vector y_hf = vy[i];
95
- HVX_Vector x_hf = vx[i];
96
-
97
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
98
-
99
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
163
+ #pragma unroll(2)
164
+ for (i = 0; i < nvec; ++i) {
165
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
100
166
  }
101
167
 
102
168
  if (nloe) {
103
- HVX_Vector y_hf = vy[i];
169
+ HVX_VectorPair xy_p = vy_p[i];
170
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
104
171
 
105
- // Load x (fp16) and zero-out unused elements
106
- HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
107
- HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
172
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
173
+ i = 2 * i; // index for vy
108
174
 
109
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
175
+ if (nloe >= VLEN_FP32) {
176
+ vy[i] = xy;
177
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
178
+ }
110
179
 
111
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
180
+ if (nloe) {
181
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
182
+ }
112
183
  }
113
-
114
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
115
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
116
- hvx_vec_store_u(r, 4, rsum);
117
184
  }
118
185
 
119
- // MAD: y (F32) += x (F16) * v (float)
120
- static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
121
- const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
122
- HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
186
+ // MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
187
+ static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
188
+ const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {
189
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
190
+ const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
123
191
 
124
- uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
125
- uint32_t nloe = n % VLEN_FP16; // leftover elements
192
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
193
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
126
194
 
127
- HVX_Vector S = hvx_vec_splat_fp16(s);
195
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
196
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
197
+
198
+ HVX_Vector S0 = hvx_vec_splat_f16(*s0);
199
+ HVX_Vector S1 = hvx_vec_splat_f16(*s1);
128
200
 
129
201
  uint32_t i = 0;
130
- #pragma unroll(4)
202
+
203
+ #pragma unroll(2)
131
204
  for (i = 0; i < nvec; ++i) {
132
- // Multiply x * s -> pair of F32 vectors
133
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
134
- ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
135
- ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
205
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
206
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
136
207
  }
137
208
 
138
209
  if (nloe) {
139
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
210
+ HVX_VectorPair xy_p = vy_p[i];
211
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
212
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
140
213
 
141
- HVX_Vector xs = Q6_V_lo_W(xs_p);
142
- i = 2 * i; // index for ptr_y
214
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
215
+ i = 2 * i; // index for vy
143
216
 
144
- if (nloe >= 32) {
145
- ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
146
- nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
217
+ if (nloe >= VLEN_FP32) {
218
+ vy[i] = xy;
219
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
147
220
  }
148
221
 
149
222
  if (nloe) {
150
- HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
151
- hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
223
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
152
224
  }
153
225
  }
154
226
  }
155
227
 
156
- #define FLASH_ATTN_BLOCK_SIZE 128
228
+ struct htp_fa_context {
229
+ const struct htp_ops_context * octx;
230
+
231
+ struct fastdiv_values src0_div21;
232
+ struct fastdiv_values src0_div1;
233
+
234
+ struct fastdiv_values broadcast_rk2;
235
+ struct fastdiv_values broadcast_rk3;
236
+ struct fastdiv_values broadcast_rv2;
237
+ struct fastdiv_values broadcast_rv3;
238
+
239
+ struct fastdiv_values src3_div2;
240
+ struct fastdiv_values src3_div3;
241
+
242
+ float scale;
243
+ float max_bias;
244
+ float logit_softcap;
245
+
246
+ uint32_t n_head_log2;
247
+ float m0;
248
+ float m1;
249
+ float slopes[512];
250
+
251
+ uint32_t n_blocks;
252
+
253
+ size_t size_q_row_padded;
254
+ size_t size_k_row_padded;
255
+ size_t size_v_row_padded;
256
+
257
+ size_t size_k_block;
258
+ size_t size_v_block;
259
+ size_t size_m_block;
260
+
261
+ uint32_t qrows;
262
+ uint32_t qrows_per_thread;
263
+
264
+ bool is_q_fp32;
265
+
266
+ uint64_t t_start;
267
+ };
268
+
269
+ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
270
+ assert((size_t) dst % 128 == 0);
271
+ assert((size_t) src % 128 == 0);
157
272
 
158
- static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
159
- const struct htp_tensor * q = &octx->src0;
160
- const struct htp_tensor * k = &octx->src1;
161
- const struct htp_tensor * v = &octx->src2;
162
- const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
163
- const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
164
- struct htp_tensor * dst = &octx->dst;
273
+ const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
274
+ HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
275
+
276
+ const uint32_t nvec = n / VLEN_FP32;
277
+ const uint32_t nloe = n % VLEN_FP32;
278
+
279
+ uint32_t i = 0;
280
+ #pragma unroll(4)
281
+ for (; i < nvec; ++i) {
282
+ vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
283
+ }
284
+ if (nloe) {
285
+ hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
286
+ }
287
+ }
288
+
289
+ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
290
+ struct htp_fa_context * factx = (struct htp_fa_context *) data;
291
+ const struct htp_ops_context * octx = factx->octx;
292
+ const struct htp_tensor * q = octx->src[0];
293
+ const struct htp_tensor * k = octx->src[1];
294
+ const struct htp_tensor * v = octx->src[2];
295
+ const struct htp_tensor * mask = octx->src[3];
296
+ const struct htp_tensor * sinks = octx->src[4];
297
+ const struct htp_tensor * dst = octx->dst;
165
298
 
166
299
  const uint32_t neq0 = q->ne[0];
167
300
  const uint32_t neq1 = q->ne[1];
@@ -198,22 +331,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
198
331
  const uint32_t nb2 = dst->nb[2];
199
332
  const uint32_t nb3 = dst->nb[3];
200
333
 
201
- float scale = 1.0f;
202
- float max_bias = 0.0f;
203
- float logit_softcap = 0.0f;
204
-
205
- memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
206
- memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
207
- memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
208
-
209
- if (logit_softcap != 0) {
210
- scale /= logit_softcap;
211
- }
212
-
213
334
  // total rows in q
214
- const uint32_t nr = neq1*neq2*neq3;
215
-
216
- const uint32_t dr = (nr + nth - 1) / nth;
335
+ const uint32_t nr = factx->qrows;
336
+ const uint32_t dr = factx->qrows_per_thread;
217
337
  const uint32_t ir0 = dr * ith;
218
338
  const uint32_t ir1 = MIN(ir0 + dr, nr);
219
339
 
@@ -225,18 +345,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
225
345
  const uint32_t DV = nev0;
226
346
 
227
347
  const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
228
- const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
229
-
230
348
  const size_t size_k_row = DK * sizeof(__fp16);
231
349
  const size_t size_v_row = DV * sizeof(__fp16);
232
- const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
233
-
234
- const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
235
- const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
236
-
237
- const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
238
- const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
239
- const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
240
350
 
241
351
  // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
242
352
  uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -245,72 +355,81 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
245
355
  uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
246
356
  uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
247
357
 
248
- const uint32_t n_head = neq2;
249
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
250
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
251
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
358
+ const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
359
+
360
+ dma_cache m_cache;
361
+ dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
252
362
 
253
363
  for (uint32_t ir = ir0; ir < ir1; ++ir) {
254
- const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
255
- const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
364
+ const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
365
+ const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
256
366
  const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
257
367
 
258
- const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
259
- const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
368
+ const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
369
+ const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
260
370
 
261
- const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
262
- const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
371
+ const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
372
+ const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
263
373
 
264
374
  // Fetch Q row
265
375
  const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
266
- dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
267
-
268
- const uint32_t h = iq2; // head index
269
- const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
376
+ dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
270
377
 
271
- float S = 0.0f; // sum
272
- float M = -INFINITY; // maximum KQ value
273
-
274
- // Clear accumulator
275
- float * VKQ32 = (float *) spad_a;
276
- memset(VKQ32, 0, DV * sizeof(float));
378
+ // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,
379
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
277
380
 
278
381
  const __fp16 * mp_base = NULL;
279
382
  if (mask) {
280
- const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
281
- const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
383
+ const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
384
+ const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
282
385
  mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
283
386
  }
284
387
 
285
- const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
286
-
287
388
  // Prefetch first two blocks
288
- for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
389
+ for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
289
390
  const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
290
391
  const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
291
392
 
292
393
  // K
293
394
  const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
294
- uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
295
- dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
395
+ uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
396
+ dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
296
397
 
297
398
  // V
298
399
  const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
299
- uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
300
- dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
400
+ uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
401
+ dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
301
402
 
302
403
  // Mask
303
404
  if (mask) {
304
405
  const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
305
- uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
306
406
  // Mask is 1D contiguous for this row
307
- dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
407
+ dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
308
408
  }
409
+
410
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
411
+ // ith, ir, ib, iq1, iq2, iq3,
412
+ // size_k_row, size_v_row, current_block_size,
413
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
309
414
  }
310
415
 
311
- const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
416
+ const uint32_t h = iq2; // head index
417
+ const float slope = factx->slopes[h];
312
418
 
313
- for (uint32_t ib = 0; ib < n_blocks; ++ib) {
419
+ HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
420
+ HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
421
+
422
+ // Clear accumulator
423
+ hvx_splat_f32_a(spad_a, 0, DV);
424
+ float * VKQ32 = (float *) (spad_a + 0);
425
+
426
+ uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
427
+ if (factx->is_q_fp32) {
428
+ hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
429
+ }
430
+
431
+ const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
432
+ for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
314
433
  const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
315
434
  const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
316
435
 
@@ -319,156 +438,162 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
319
438
  uint8_t * v_base = dma_queue_pop(dma).dst; // V
320
439
  __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
321
440
 
441
+ // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u",
442
+ // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,
443
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
444
+
322
445
  // Inner loop processing the block from VTCM
323
446
  uint32_t ic = 0;
324
447
 
325
- // Process in blocks of 32 (VLEN_FP32)
326
- for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
448
+ // Process in sub-blocks of 32 (VLEN_FP32)
449
+ HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
450
+ HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
451
+ for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) {
327
452
  // 1. Compute scores
328
- float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
329
- for (int j = 0; j < VLEN_FP32; ++j) {
330
- const uint32_t cur_ic = ic + j;
331
- const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
332
- if (q->type == HTP_TYPE_F32) {
333
- hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
334
- } else {
335
- hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
336
- }
337
- }
338
-
339
- HVX_Vector scores = *(HVX_Vector *) scores_arr;
453
+ HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
340
454
 
341
455
  // 2. Softcap
342
- if (logit_softcap != 0.0f) {
343
- scores = hvx_vec_tanh_fp32(scores);
344
- scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
345
- scores = Q6_Vsf_equals_Vqf32(scores);
456
+ if (factx->logit_softcap != 0.0f) {
457
+ scores = hvx_vec_tanh_f32(scores);
458
+ scores = HVX_OP_MUL_F32(scores, logit_cap);
346
459
  }
347
460
 
348
461
  // 3. Mask
349
462
  if (mask) {
350
463
  const __fp16 * mp = m_base + ic;
351
- HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
352
-
353
- HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
354
- HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
355
-
356
- HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
357
-
358
- HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
359
- HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
360
- scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
361
- scores = Q6_Vsf_equals_Vqf32(scores);
464
+ HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
465
+
466
+ // Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79.
467
+ // Clamp -INFINITY to the max negative fp16 finite value (-65504.0f).
468
+ HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00);
469
+ HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF);
470
+ HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf);
471
+ m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16);
472
+
473
+ #if __HVX_ARCH__ >= 79
474
+ HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
475
+ HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
476
+ scores = Q6_Vsf_vadd_VsfVsf(add_val, scores);
477
+ #else
478
+ HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
479
+ HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
480
+ scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores));
481
+ #endif
362
482
  }
363
483
 
364
- // 4. Online Softmax Update
365
- HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
366
- float m_block = hvx_vec_get_fp32(v_max);
367
-
368
- float M_old = M;
369
- float M_new = (m_block > M) ? m_block : M;
370
- M = M_new;
371
-
372
- float ms = expf(M_old - M_new);
373
-
374
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
375
- S = S * ms;
376
-
377
- HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
378
- HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
379
- HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
380
-
381
- HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
382
- float p_sum = hvx_vec_get_fp32(p_sum_vec);
383
- S += p_sum;
384
-
385
- // 5. Accumulate V
386
- float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
387
- *(HVX_Vector*)p_arr = P;
388
-
389
- for (int j = 0; j < VLEN_FP32; ++j) {
390
- const uint32_t cur_ic = ic + j;
391
- const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
392
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
484
+ // Mask out invalid lanes for leftover handling
485
+ uint32_t valid_lanes = current_block_size - ic;
486
+ if (valid_lanes < VLEN_FP32) {
487
+ HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane
488
+ scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY));
393
489
  }
394
- }
395
490
 
396
- // Leftover
397
- for (; ic < current_block_size; ++ic) {
398
- float s_val;
399
- const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
400
-
401
- if (q->type == HTP_TYPE_F32) {
402
- hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
403
- } else {
404
- hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
405
- }
406
-
407
- if (logit_softcap != 0.0f) {
408
- s_val = logit_softcap * tanhf(s_val);
409
- }
410
-
411
- if (mask) {
412
- const float m_val = m_base[ic];
413
- s_val += slope * m_val;
414
- }
415
-
416
- const float Mold = M;
417
- float ms = 1.0f;
418
- float vs = 1.0f;
491
+ sb_scores[iv] = scores;
492
+ v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
493
+ }
419
494
 
420
- if (s_val > M) {
421
- M = s_val;
422
- ms = expf(Mold - M);
423
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
424
- } else {
425
- vs = expf(s_val - M);
495
+ {
496
+ // 4. Online Softmax Update
497
+ HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
498
+ HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec);
499
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
500
+ M_vec = M_new_vec;
501
+
502
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
503
+
504
+ HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
505
+ for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) {
506
+ HVX_Vector scores = sb_scores[iv];
507
+ HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec);
508
+ HVX_Vector P = hvx_vec_exp_f32(scores_shifted);
509
+
510
+ p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P);
511
+
512
+ // 5. Accumulate V
513
+ __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
514
+ hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
515
+
516
+ float __attribute__((aligned(128))) P_arr[VLEN_FP32];
517
+ hvx_vec_store_a(P_arr, 128, P);
518
+
519
+ for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
520
+ const uint32_t cur_ic = ic2 + j;
521
+ if (cur_ic >= current_block_size) {
522
+ break;
523
+ }
524
+
525
+ if (cur_ic + 1 == current_block_size) {
526
+ // Odd leftover, process single row
527
+ if (P_arr[j] != 0.0f) {
528
+ const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
529
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV);
530
+ }
531
+ break;
532
+ }
533
+
534
+ // Avoid NaN * 0.0 = NaN for uninitialized V cache rows.
535
+ // Check the f32 values to safely avoid strict aliasing violations.
536
+ if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) {
537
+ continue;
538
+ }
539
+
540
+ const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
541
+ hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
542
+ }
426
543
  }
427
544
 
428
- const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
429
-
430
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
431
-
432
- S = S * ms + vs;
545
+ p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
546
+ S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec);
433
547
  }
434
548
 
435
549
  // Issue DMA for next+1 block (if exists)
436
- if (ib + 2 < n_blocks) {
550
+ if (ib + 2 < factx->n_blocks) {
437
551
  const uint32_t next_ib = ib + 2;
438
552
  const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
439
553
  const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
440
554
 
441
555
  // K
442
556
  const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
443
- dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
557
+ dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
444
558
 
445
559
  // V
446
560
  const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
447
- dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
561
+ dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
448
562
 
449
563
  // Mask
450
564
  if (mask) {
451
565
  const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
452
- dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
566
+ dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
453
567
  }
568
+
569
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
570
+ // ith, ir, next_ib, iq1, iq2, iq3,
571
+ // size_k_row, size_v_row, next_block_size,
572
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
454
573
  }
455
574
  }
456
575
 
457
576
  // sinks
577
+ float M = hvx_vec_get_f32(M_vec);
578
+ float S = hvx_vec_get_f32(S_vec);
579
+
458
580
  if (sinks) {
459
581
  const float s = ((float *)((char *) sinks->data))[h];
460
582
 
461
- float ms = 1.0f;
462
583
  float vs = 1.0f;
463
584
 
464
585
  if (s > M) {
465
- ms = expf(M - s);
466
- hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
586
+ HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
587
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
588
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
589
+
590
+ float ms = hvx_vec_get_f32(ms_vec);
591
+ S = S * ms + vs;
467
592
  } else {
468
- vs = expf(s - M);
593
+ HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
594
+ vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
595
+ S += vs;
469
596
  }
470
-
471
- S = S * ms + vs;
472
597
  }
473
598
 
474
599
  const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -480,64 +605,114 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
480
605
  const int i2 = iq2;
481
606
  const int i3 = iq3;
482
607
 
483
- // dst is permuted
484
- uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
608
+ // dst is permuted: [DV, n_heads, n_tokens, n_seq]
609
+ // head stride is nb[1], token stride is nb[2], batch stride is nb[3]
610
+ uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3];
485
611
 
486
612
  if (dst->type == HTP_TYPE_F32) {
487
- hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
613
+ hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
488
614
  } else if (dst->type == HTP_TYPE_F16) {
489
- hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
615
+ hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
490
616
  }
491
617
  }
492
618
  }
493
619
 
494
- static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
495
- struct htp_ops_context * octx = data;
496
- flash_attn_ext_f16_thread(octx, i, n);
497
- }
498
-
499
620
  int op_flash_attn_ext(struct htp_ops_context * octx) {
500
- const struct htp_tensor * q = &octx->src0;
501
- const struct htp_tensor * k = &octx->src1;
502
- const struct htp_tensor * v = &octx->src2;
503
- const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
504
- struct htp_tensor * dst = &octx->dst;
621
+ const struct htp_tensor * q = octx->src[0];
622
+ const struct htp_tensor * k = octx->src[1];
623
+ const struct htp_tensor * v = octx->src[2];
624
+ const struct htp_tensor * mask = octx->src[3];
625
+ const struct htp_tensor * dst = octx->dst;
505
626
 
506
627
  // Check support
507
- if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
508
- k->type != HTP_TYPE_F16 ||
509
- v->type != HTP_TYPE_F16) {
628
+ if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
510
629
  return HTP_STATUS_NO_SUPPORT;
511
630
  }
512
631
 
513
- octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
514
- octx->src0_div1 = init_fastdiv_values(q->ne[1]);
632
+ #ifdef HTP_HAS_HMX
633
+ // HMX path: head_dim multiple of 64, F16 KV, and no sinks
634
+ if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
635
+ int ret = hmx_flash_attn_ext(octx);
636
+ if (ret == HTP_STATUS_OK) {
637
+ return ret;
638
+ }
639
+ // VTCM too small or other failure -> fall through to HVX path
640
+ }
641
+ #endif
642
+
643
+ struct htp_fa_context factx;
644
+ factx.octx = octx;
645
+
646
+ factx.t_start = HAP_perf_get_qtimer_count();
515
647
 
516
- octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
517
- octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
518
- octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
519
- octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
648
+ factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
649
+ factx.src0_div1 = init_fastdiv_values(q->ne[1]);
650
+
651
+ factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
652
+ factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
653
+ factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
654
+ factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
520
655
 
521
656
  if (mask) {
522
- octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
523
- octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
657
+ factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
658
+ factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
659
+ }
660
+
661
+ factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
662
+ factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
663
+ factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
664
+ factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
665
+
666
+ size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
667
+ factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
668
+ factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
669
+ factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
670
+
671
+ factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
672
+
673
+ float scale = 1.0f;
674
+ float max_bias = 0.0f;
675
+ float logit_softcap = 0.0f;
676
+
677
+ memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
678
+ memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
679
+ memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
680
+
681
+ if (logit_softcap != 0.0f) {
682
+ scale /= logit_softcap;
524
683
  }
525
684
 
526
- size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
527
- size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
528
- size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
685
+ factx.scale = scale;
686
+ factx.max_bias = max_bias;
687
+ factx.logit_softcap = logit_softcap;
688
+
689
+ uint32_t n_head = q->ne[2];
690
+ factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
691
+ factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
692
+ factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
693
+
694
+ if (n_head > 512) {
695
+ return HTP_STATUS_NO_SUPPORT;
696
+ }
697
+ for (uint32_t h = 0; h < n_head; ++h) {
698
+ factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f;
699
+ }
700
+
701
+ // total rows in q
702
+ const uint32_t neq0 = q->ne[0];
703
+ const uint32_t neq1 = q->ne[1];
704
+ const uint32_t neq2 = q->ne[2];
705
+ const uint32_t neq3 = q->ne[3];
529
706
 
530
- size_t size_q_block = size_q_row_padded * 1; // single row for now
531
- size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
532
- size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
533
- size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
707
+ factx.qrows = neq1*neq2*neq3;
708
+ factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;
534
709
 
535
- size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
710
+ size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
536
711
 
537
712
  octx->src0_spad.size_per_thread = size_q_block * 1;
538
- octx->src1_spad.size_per_thread = size_k_block * 2;
539
- octx->src2_spad.size_per_thread = size_v_block * 2;
540
- octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
713
+ octx->src1_spad.size_per_thread = factx.size_k_block * 2;
714
+ octx->src2_spad.size_per_thread = factx.size_v_block * 2;
715
+ octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
541
716
  octx->dst_spad.size_per_thread = size_vkq_acc;
542
717
 
543
718
  octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -552,14 +727,14 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
552
727
  return HTP_STATUS_VTCM_TOO_SMALL;
553
728
  }
554
729
 
555
- octx->src0_spad.data = octx->ctx->vtcm_base;
556
- octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
557
- octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
558
- octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
559
- octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
730
+ octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
731
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
732
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL;
733
+ octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL;
734
+ octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL;
560
735
 
561
736
  if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
562
- worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
737
+ worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
563
738
  }
564
739
 
565
740
  return HTP_STATUS_OK;