whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -0,0 +1,1148 @@
1
+ #include <math.h>
2
+ #include <stdint.h>
3
+ #include <string.h>
4
+
5
+ #include "hvx-utils.h"
6
+ #include "hex-fastdiv.h"
7
+
8
+ #define GGML_COMMON_DECL_C
9
+ #include "ggml-common.h"
10
+ #include "htp-ctx.h"
11
+
12
+ #ifndef MIN // only define MIN when no earlier header has already provided it
13
+ #define MIN(a, b) ((a) < (b) ? (a) : (b)) // smaller of a and b; each argument may be evaluated twice, so avoid side effects
14
+ #endif
15
+
16
+ #define HTP_GDN_MAX_SV 128 // NOTE(review): limit used outside this chunk; meaning and units not visible here - confirm at the use site
17
+
18
+
19
+ struct htp_gdn_context { // state bundle for the GDN kernels; NOTE(review): field meanings below are inferred from names only - confirm against the code that fills them in
20
+ struct htp_ops_context * octx; // pointer to the enclosing ops context
21
+ uint32_t rows_per_thread; // presumably the number of rows assigned to each worker thread - TODO confirm
22
+ size_t state_bytes; // presumably the size in bytes of one state buffer - TODO confirm
23
+ uint8_t * vtcm_base; // presumably the start of the VTCM scratch region used by this op - TODO confirm
24
+ size_t vtcm_per_thread; // presumably the number of VTCM bytes reserved per thread - TODO confirm
25
+ };
26
+
27
+ static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) {
28
+ HVX_Vector acc = Q6_V_vzero();
29
+
30
+ const uint32_t epv = 128 / sizeof(float);
31
+ const uint32_t nvec = n / epv;
32
+ const uint32_t nloe = n % epv;
33
+ for (uint32_t i = 0; i < nvec; ++i) {
34
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
35
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
36
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
37
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
38
+ hvx_vmemu(dst + i * epv) = out;
39
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
40
+ }
41
+
42
+ if (nloe) {
43
+ const uint32_t off = nvec * epv;
44
+ HVX_Vector vd = hvx_vmemu(dst + off);
45
+ HVX_Vector vm = hvx_vmem(mul + off);
46
+ HVX_Vector vdot = hvx_vmem(dot + off);
47
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
48
+ hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
49
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
50
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
51
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
52
+ }
53
+
54
+ return hvx_vec_reduce_sum_f32(acc);
55
+ }
56
+
57
+ static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) {
58
+ HVX_Vector acc = Q6_V_vzero();
59
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
60
+
61
+ const uint32_t epv = 128 / sizeof(float);
62
+ const uint32_t nvec = n / epv;
63
+ const uint32_t nloe = n % epv;
64
+ for (uint32_t i = 0; i < nvec; ++i) {
65
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
66
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
67
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
68
+ hvx_vmemu(dst + i * epv) = out;
69
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
70
+ }
71
+
72
+ if (nloe) {
73
+ const uint32_t off = nvec * epv;
74
+ HVX_Vector vd = hvx_vmemu(dst + off);
75
+ HVX_Vector vdot = hvx_vmem(dot + off);
76
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
77
+ hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
78
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
79
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
80
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
81
+ }
82
+
83
+ return hvx_vec_reduce_sum_f32(acc);
84
+ }
85
+
86
+ static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
87
+ HVX_Vector vscale, const float * restrict dot, uint32_t n) {
88
+ HVX_Vector acc = Q6_V_vzero();
89
+
90
+ const uint32_t epv = 128 / sizeof(float);
91
+ const uint32_t nvec = n / epv;
92
+ const uint32_t nloe = n % epv;
93
+ for (uint32_t i = 0; i < nvec; ++i) {
94
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
95
+ HVX_Vector vs = hvx_vmem(src + i * epv);
96
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
97
+ HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
98
+ hvx_vmemu(dst + i * epv) = out;
99
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
100
+ }
101
+
102
+ if (nloe) {
103
+ const uint32_t off = nvec * epv;
104
+ HVX_Vector vd = hvx_vmemu(dst + off);
105
+ HVX_Vector vs = hvx_vmem(src + off);
106
+ HVX_Vector vdot = hvx_vmem(dot + off);
107
+ HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
108
+ hvx_vec_store_u(dst + off, nloe * sizeof(float), out);
109
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
110
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
111
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
112
+ }
113
+
114
+ return hvx_vec_reduce_sum_f32(acc);
115
+ }
116
+
117
+ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1,
118
+ float * restrict dst2, float * restrict dst3, const float * restrict mul,
119
+ const float * restrict dot, uint32_t n, float * restrict sums) {
120
+ HVX_Vector acc0 = Q6_V_vzero();
121
+ HVX_Vector acc1 = Q6_V_vzero();
122
+ HVX_Vector acc2 = Q6_V_vzero();
123
+ HVX_Vector acc3 = Q6_V_vzero();
124
+
125
+ const uint32_t epv = 128 / sizeof(float);
126
+ const uint32_t nvec = n / epv;
127
+ const uint32_t nloe = n % epv;
128
+ for (uint32_t i = 0; i < nvec; ++i) {
129
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
130
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
131
+
132
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
133
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
134
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
135
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
136
+
137
+ hvx_vmemu(dst0 + i * epv) = out0;
138
+ hvx_vmemu(dst1 + i * epv) = out1;
139
+ hvx_vmemu(dst2 + i * epv) = out2;
140
+ hvx_vmemu(dst3 + i * epv) = out3;
141
+
142
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
143
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
144
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
145
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
146
+ }
147
+
148
+ if (nloe) {
149
+ const uint32_t off = nvec * epv;
150
+ HVX_Vector vm = hvx_vmem(mul + off);
151
+ HVX_Vector vdot = hvx_vmem(dot + off);
152
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
153
+ HVX_Vector zero = Q6_V_vzero();
154
+
155
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
156
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
157
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
158
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
159
+
160
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
161
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
162
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
163
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
164
+
165
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
166
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
167
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
168
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
169
+ }
170
+
171
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
172
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
173
+ }
174
+
175
+ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1,
176
+ float * restrict dst2, float * restrict dst3, float mul,
177
+ const float * restrict dot, uint32_t n, float * restrict sums) {
178
+ HVX_Vector acc0 = Q6_V_vzero();
179
+ HVX_Vector acc1 = Q6_V_vzero();
180
+ HVX_Vector acc2 = Q6_V_vzero();
181
+ HVX_Vector acc3 = Q6_V_vzero();
182
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
183
+
184
+ const uint32_t epv = 128 / sizeof(float);
185
+ const uint32_t nvec = n / epv;
186
+ const uint32_t nloe = n % epv;
187
+ for (uint32_t i = 0; i < nvec; ++i) {
188
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
189
+
190
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
191
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
192
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
193
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
194
+
195
+ hvx_vmemu(dst0 + i * epv) = out0;
196
+ hvx_vmemu(dst1 + i * epv) = out1;
197
+ hvx_vmemu(dst2 + i * epv) = out2;
198
+ hvx_vmemu(dst3 + i * epv) = out3;
199
+
200
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
201
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
202
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
203
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
204
+ }
205
+
206
+ if (nloe) {
207
+ const uint32_t off = nvec * epv;
208
+ HVX_Vector vdot = hvx_vmem(dot + off);
209
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
210
+ HVX_Vector zero = Q6_V_vzero();
211
+
212
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
213
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
214
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
215
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
216
+
217
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
218
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
219
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
220
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
221
+
222
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
223
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
224
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
225
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
226
+ }
227
+
228
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
229
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
230
+ }
231
+
232
+ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1,
233
+ float * restrict dst2, float * restrict dst3, const float * restrict src,
234
+ const float * restrict scale, const float * restrict dot, uint32_t n,
235
+ float * restrict sums) {
236
+ HVX_Vector acc0 = Q6_V_vzero();
237
+ HVX_Vector acc1 = Q6_V_vzero();
238
+ HVX_Vector acc2 = Q6_V_vzero();
239
+ HVX_Vector acc3 = Q6_V_vzero();
240
+ const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
241
+ const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
242
+ const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
243
+ const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
244
+
245
+ const uint32_t epv = 128 / sizeof(float);
246
+ const uint32_t nvec = n / epv;
247
+ const uint32_t nloe = n % epv;
248
+ for (uint32_t i = 0; i < nvec; ++i) {
249
+ HVX_Vector vs = hvx_vmem(src + i * epv);
250
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
251
+
252
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
253
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
254
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
255
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
256
+
257
+ hvx_vmemu(dst0 + i * epv) = out0;
258
+ hvx_vmemu(dst1 + i * epv) = out1;
259
+ hvx_vmemu(dst2 + i * epv) = out2;
260
+ hvx_vmemu(dst3 + i * epv) = out3;
261
+
262
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
263
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
264
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
265
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
266
+ }
267
+
268
+ if (nloe) {
269
+ const uint32_t off = nvec * epv;
270
+ HVX_Vector vs = hvx_vmem(src + off);
271
+ HVX_Vector vdot = hvx_vmem(dot + off);
272
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
273
+ HVX_Vector zero = Q6_V_vzero();
274
+
275
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
276
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
277
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
278
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
279
+
280
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
281
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
282
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
283
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
284
+
285
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
286
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
287
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
288
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
289
+ }
290
+
291
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
292
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
293
+ }
294
+
295
+ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1,
296
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
297
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
298
+ const float * restrict mul, const float * restrict dot, uint32_t n,
299
+ float * restrict sums) {
300
+ HVX_Vector acc0 = Q6_V_vzero();
301
+ HVX_Vector acc1 = Q6_V_vzero();
302
+ HVX_Vector acc2 = Q6_V_vzero();
303
+ HVX_Vector acc3 = Q6_V_vzero();
304
+ HVX_Vector acc4 = Q6_V_vzero();
305
+ HVX_Vector acc5 = Q6_V_vzero();
306
+ HVX_Vector acc6 = Q6_V_vzero();
307
+ HVX_Vector acc7 = Q6_V_vzero();
308
+
309
+ const uint32_t epv = 128 / sizeof(float);
310
+ const uint32_t nvec = n / epv;
311
+ const uint32_t nloe = n % epv;
312
+ for (uint32_t i = 0; i < nvec; ++i) {
313
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
314
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
315
+
316
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
317
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
318
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
319
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
320
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm);
321
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm);
322
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm);
323
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm);
324
+
325
+ hvx_vmemu(dst0 + i * epv) = out0;
326
+ hvx_vmemu(dst1 + i * epv) = out1;
327
+ hvx_vmemu(dst2 + i * epv) = out2;
328
+ hvx_vmemu(dst3 + i * epv) = out3;
329
+ hvx_vmemu(dst4 + i * epv) = out4;
330
+ hvx_vmemu(dst5 + i * epv) = out5;
331
+ hvx_vmemu(dst6 + i * epv) = out6;
332
+ hvx_vmemu(dst7 + i * epv) = out7;
333
+
334
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
335
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
336
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
337
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
338
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
339
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
340
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
341
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
342
+ }
343
+
344
+ if (nloe) {
345
+ const uint32_t off = nvec * epv;
346
+ HVX_Vector vm = hvx_vmem(mul + off);
347
+ HVX_Vector vdot = hvx_vmem(dot + off);
348
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
349
+ HVX_Vector zero = Q6_V_vzero();
350
+
351
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
352
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
353
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
354
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
355
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm);
356
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm);
357
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm);
358
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm);
359
+
360
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
361
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
362
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
363
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
364
+ hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
365
+ hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
366
+ hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
367
+ hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
368
+
369
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
370
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
371
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
372
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
373
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
374
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
375
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
376
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
377
+ }
378
+
379
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
380
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
381
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
382
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
383
+ }
384
+
385
+ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1,
386
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
387
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
388
+ float mul, const float * restrict dot, uint32_t n, float * restrict sums) {
389
+ HVX_Vector acc0 = Q6_V_vzero();
390
+ HVX_Vector acc1 = Q6_V_vzero();
391
+ HVX_Vector acc2 = Q6_V_vzero();
392
+ HVX_Vector acc3 = Q6_V_vzero();
393
+ HVX_Vector acc4 = Q6_V_vzero();
394
+ HVX_Vector acc5 = Q6_V_vzero();
395
+ HVX_Vector acc6 = Q6_V_vzero();
396
+ HVX_Vector acc7 = Q6_V_vzero();
397
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
398
+
399
+ const uint32_t epv = 128 / sizeof(float);
400
+ const uint32_t nvec = n / epv;
401
+ const uint32_t nloe = n % epv;
402
+ for (uint32_t i = 0; i < nvec; ++i) {
403
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
404
+
405
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
406
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
407
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
408
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
409
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul);
410
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul);
411
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul);
412
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul);
413
+
414
+ hvx_vmemu(dst0 + i * epv) = out0;
415
+ hvx_vmemu(dst1 + i * epv) = out1;
416
+ hvx_vmemu(dst2 + i * epv) = out2;
417
+ hvx_vmemu(dst3 + i * epv) = out3;
418
+ hvx_vmemu(dst4 + i * epv) = out4;
419
+ hvx_vmemu(dst5 + i * epv) = out5;
420
+ hvx_vmemu(dst6 + i * epv) = out6;
421
+ hvx_vmemu(dst7 + i * epv) = out7;
422
+
423
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
424
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
425
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
426
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
427
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
428
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
429
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
430
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
431
+ }
432
+
433
+ if (nloe) {
434
+ const uint32_t off = nvec * epv;
435
+ HVX_Vector vdot = hvx_vmem(dot + off);
436
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
437
+ HVX_Vector zero = Q6_V_vzero();
438
+
439
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
440
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
441
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
442
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
443
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul);
444
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul);
445
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul);
446
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul);
447
+
448
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
449
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
450
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
451
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
452
+ hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
453
+ hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
454
+ hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
455
+ hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
456
+
457
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
458
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
459
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
460
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
461
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
462
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
463
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
464
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
465
+ }
466
+
467
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
468
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
469
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
470
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
471
+ }
472
+
473
+ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1,
474
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
475
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
476
+ const float * restrict src, const float * restrict scale,
477
+ const float * restrict dot, uint32_t n, float * restrict sums) {
478
+ HVX_Vector acc0 = Q6_V_vzero();
479
+ HVX_Vector acc1 = Q6_V_vzero();
480
+ HVX_Vector acc2 = Q6_V_vzero();
481
+ HVX_Vector acc3 = Q6_V_vzero();
482
+ HVX_Vector acc4 = Q6_V_vzero();
483
+ HVX_Vector acc5 = Q6_V_vzero();
484
+ HVX_Vector acc6 = Q6_V_vzero();
485
+ HVX_Vector acc7 = Q6_V_vzero();
486
+ const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
487
+ const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
488
+ const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
489
+ const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
490
+ const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]);
491
+ const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]);
492
+ const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]);
493
+ const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]);
494
+
495
+ const uint32_t epv = 128 / sizeof(float);
496
+ const uint32_t nvec = n / epv;
497
+ const uint32_t nloe = n % epv;
498
+ for (uint32_t i = 0; i < nvec; ++i) {
499
+ HVX_Vector vs = hvx_vmem(src + i * epv);
500
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
501
+
502
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
503
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
504
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
505
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
506
+ HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4));
507
+ HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5));
508
+ HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6));
509
+ HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7));
510
+
511
+ hvx_vmemu(dst0 + i * epv) = out0;
512
+ hvx_vmemu(dst1 + i * epv) = out1;
513
+ hvx_vmemu(dst2 + i * epv) = out2;
514
+ hvx_vmemu(dst3 + i * epv) = out3;
515
+ hvx_vmemu(dst4 + i * epv) = out4;
516
+ hvx_vmemu(dst5 + i * epv) = out5;
517
+ hvx_vmemu(dst6 + i * epv) = out6;
518
+ hvx_vmemu(dst7 + i * epv) = out7;
519
+
520
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
521
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
522
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
523
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
524
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
525
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
526
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
527
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
528
+ }
529
+
530
+ if (nloe) {
531
+ const uint32_t off = nvec * epv;
532
+ HVX_Vector vs = hvx_vmem(src + off);
533
+ HVX_Vector vdot = hvx_vmem(dot + off);
534
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float));
535
+ HVX_Vector zero = Q6_V_vzero();
536
+
537
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
538
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
539
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
540
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
541
+ HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4));
542
+ HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5));
543
+ HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6));
544
+ HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7));
545
+
546
+ hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0);
547
+ hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1);
548
+ hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2);
549
+ hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3);
550
+ hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4);
551
+ hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5);
552
+ hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6);
553
+ hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7);
554
+
555
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
556
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
557
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
558
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
559
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
560
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
561
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
562
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
563
+ }
564
+
565
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
566
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
567
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
568
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
569
+ }
570
+
571
+ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) {
572
+ struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
573
+ struct htp_ops_context * octx = gctx->octx;
574
+
575
+ const struct htp_tensor * q = octx->src[0];
576
+ const struct htp_tensor * k = octx->src[1];
577
+ const struct htp_tensor * v = octx->src[2];
578
+ const struct htp_tensor * g = octx->src[3];
579
+ const struct htp_tensor * beta = octx->src[4];
580
+ const struct htp_tensor * state = octx->src[5];
581
+ const struct htp_tensor * dst = octx->dst;
582
+
583
+ const uint32_t S_v = v->ne[0];
584
+ const uint32_t H = v->ne[1];
585
+ const uint32_t n_tokens = v->ne[2];
586
+ const uint32_t n_seqs = v->ne[3];
587
+ const uint32_t K = octx->op_params[0];
588
+
589
+ const uint32_t total_rows = H * n_seqs;
590
+ if (ith >= total_rows) {
591
+ return;
592
+ }
593
+
594
+ const uint32_t rq3 = n_seqs / q->ne[3];
595
+ const uint32_t rk3 = n_seqs / k->ne[3];
596
+ const float scale = 1.0f / sqrtf((float) S_v);
597
+
598
+ float * dst_base = (float *) (uintptr_t) dst->data;
599
+ float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs;
600
+ const float * state_in_base = (const float *) (uintptr_t) state->data;
601
+
602
+ const bool kda = (g->ne[0] == S_v);
603
+ float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
604
+ float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
605
+ float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
606
+ float local_sums[32] __attribute__((aligned(128)));
607
+
608
+ dma_queue * dma = octx->ctx->dma[ith];
609
+ size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
610
+ state_aligned = (state_aligned + 127) & ~(size_t)127;
611
+ float * s_work[2];
612
+ s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith);
613
+ s_work[1] = s_work[0] + state_aligned / sizeof(float);
614
+
615
+ struct fastdiv_values fd_H = init_fastdiv_values(H);
616
+ struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]);
617
+ struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]);
618
+ struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
619
+ struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
620
+
621
+ const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
622
+ const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
623
+
624
+ uint32_t ir_prefetch = ith;
625
+ int spad_idx = 0;
626
+
627
+ // Prefetch preamble (up to 2 steps)
628
+ for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) {
629
+ const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
630
+ const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
631
+ const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
632
+ // final state lands in snapshot slot 0 (most-recent-first ordering)
633
+ float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
634
+
635
+ // Push dummy write-back
636
+ dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
637
+ S_v * sizeof(float), S_v * sizeof(float),
638
+ S_v * sizeof(float), 0);
639
+
640
+ // Push fetch
641
+ dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
642
+ S_v * sizeof(float), S_v * sizeof(float),
643
+ S_v * sizeof(float), S_v);
644
+
645
+ ir_prefetch += nth;
646
+ spad_idx ^= 1;
647
+ }
648
+
649
+ int curr_spad_idx = 0;
650
+ for (uint32_t ir = ith; ir < total_rows; ir += nth) {
651
+ dma_queue_pop(dma);
652
+ dma_queue_pop(dma);
653
+
654
+ float * s_work_curr = s_work[curr_spad_idx];
655
+
656
+ const uint32_t iv1 = fastmodulo(ir, H, &fd_H);
657
+ const uint32_t iv3 = fastdiv(ir, &fd_H);
658
+
659
+ const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1);
660
+ const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1);
661
+ const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
662
+ const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
663
+
664
+ // final state lands in snapshot slot 0 (most-recent-first ordering)
665
+ float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
666
+
667
+ float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;
668
+
669
+ for (uint32_t t = 0; t < n_tokens; ++t) {
670
+ const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
671
+ (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]);
672
+ const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
673
+ (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]);
674
+ const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
675
+ (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]);
676
+ const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
677
+ (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]);
678
+ const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
679
+ (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]);
680
+
681
+ hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v);
682
+ hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v);
683
+
684
+ if (kda) {
685
+ hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
686
+
687
+ uint32_t j = 0;
688
+ for (; j + 8 <= S_v; j += 8) {
689
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
690
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
691
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
692
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
693
+ float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
694
+ float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
695
+ float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
696
+ float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
697
+ gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
698
+ local_gate, local_k, S_v, local_sums);
699
+
700
+ float local_delta_b[32] __attribute__((aligned(128)));
701
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
702
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
703
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
704
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
705
+
706
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
707
+ local_k, local_delta_b, local_q, S_v, local_sums);
708
+
709
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
710
+ hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
711
+ }
712
+ for (; j + 4 <= S_v; j += 4) {
713
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
714
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
715
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
716
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
717
+ gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
718
+
719
+ float local_delta_b[32] __attribute__((aligned(128)));
720
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
721
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
722
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
723
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
724
+
725
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
726
+
727
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
728
+ hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
729
+ }
730
+ HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
731
+ for (; j < S_v; ++j) {
732
+ float * row = s_work_curr + (uint64_t) j * S_v;
733
+ HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
734
+ HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
735
+ HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
736
+ HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
737
+ attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
738
+ }
739
+ } else {
740
+ const float gate = expf(g_t[0]);
741
+ uint32_t j = 0;
742
+ for (; j + 8 <= S_v; j += 8) {
743
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
744
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
745
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
746
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
747
+ float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
748
+ float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
749
+ float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
750
+ float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
751
+ gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
752
+ gate, local_k, S_v, local_sums);
753
+
754
+ float local_delta_b[32] __attribute__((aligned(128)));
755
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
756
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
757
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
758
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
759
+
760
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
761
+ local_k, local_delta_b, local_q, S_v, local_sums);
762
+
763
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
764
+ hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
765
+ }
766
+ for (; j + 4 <= S_v; j += 4) {
767
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
768
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
769
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
770
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
771
+ gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
772
+
773
+ float local_delta_b[32] __attribute__((aligned(128)));
774
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
775
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
776
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
777
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
778
+
779
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
780
+
781
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
782
+ hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
783
+ }
784
+ HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
785
+ for (; j < S_v; ++j) {
786
+ float * row = s_work_curr + (uint64_t) j * S_v;
787
+ HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
788
+ HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
789
+ HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
790
+ HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
791
+ attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
792
+ }
793
+ }
794
+
795
+ if (K > 1) {
796
+ // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
797
+ const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t;
798
+ if (target_slot >= 0 && target_slot < (int64_t) K) {
799
+ float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
800
+ if (curr_state_o != s_out) {
801
+ hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v);
802
+ }
803
+ }
804
+ }
805
+
806
+ attn_data += (uint64_t) S_v * H;
807
+ }
808
+
809
+ // Push real write-back
810
+ dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr),
811
+ S_v * sizeof(float), S_v * sizeof(float),
812
+ S_v * sizeof(float), S_v);
813
+
814
+ // Prefetch next block (if any)
815
+ if (ir_prefetch < total_rows) {
816
+ const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
817
+ const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
818
+ const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
819
+
820
+ dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
821
+ S_v * sizeof(float), S_v * sizeof(float),
822
+ S_v * sizeof(float), S_v);
823
+
824
+ ir_prefetch += nth;
825
+ spad_idx ^= 1;
826
+ }
827
+
828
+ curr_spad_idx ^= 1;
829
+ }
830
+ dma_queue_flush(dma);
831
+ }
832
+
833
+
834
+ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) {
835
+ struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
836
+ struct htp_ops_context * octx = gctx->octx;
837
+
838
+ const struct htp_tensor * q = octx->src[0];
839
+ const struct htp_tensor * k = octx->src[1];
840
+ const struct htp_tensor * v = octx->src[2];
841
+ const struct htp_tensor * g = octx->src[3];
842
+ const struct htp_tensor * beta = octx->src[4];
843
+ const struct htp_tensor * state = octx->src[5];
844
+ const struct htp_tensor * dst = octx->dst;
845
+
846
+ const uint32_t S_v = v->ne[0];
847
+ const uint32_t H = v->ne[1];
848
+ const uint32_t n_seqs = v->ne[3];
849
+
850
+ const uint32_t total_rows = H * n_seqs;
851
+ if (ith >= total_rows) {
852
+ return;
853
+ }
854
+
855
+ const uint32_t rq3 = n_seqs / q->ne[3];
856
+ const uint32_t rk3 = n_seqs / k->ne[3];
857
+ const float scale = 1.0f / sqrtf((float) S_v);
858
+
859
+ float * dst_base = (float *) (uintptr_t) dst->data;
860
+ float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs;
861
+ const float * state_in_base = (const float *) (uintptr_t) state->data;
862
+
863
+ const bool kda = (g->ne[0] == S_v);
864
+ float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
865
+ float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
866
+ float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
867
+ float local_sums[32] __attribute__((aligned(128)));
868
+
869
+ dma_queue * dma = octx->ctx->dma[ith];
870
+ size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
871
+ state_aligned = (state_aligned + 127) & ~(size_t)127;
872
+ float * s_work[2];
873
+ s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith);
874
+ s_work[1] = s_work[0] + state_aligned / sizeof(float);
875
+
876
+ struct fastdiv_values fd_H = init_fastdiv_values(H);
877
+ struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]);
878
+ struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]);
879
+ struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
880
+ struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
881
+
882
+ const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
883
+
884
+ uint32_t ir_prefetch = ith;
885
+ int spad_idx = 0;
886
+
887
+ // Prefetch preamble (up to 2 steps)
888
+ for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) {
889
+ const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
890
+ const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
891
+ const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
892
+ // final state lands in snapshot slot 0 (most-recent-first ordering)
893
+ float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
894
+
895
+ // Push dummy write-back
896
+ dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
897
+ S_v * sizeof(float), S_v * sizeof(float),
898
+ S_v * sizeof(float), 0);
899
+
900
+ // Push fetch
901
+ dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
902
+ S_v * sizeof(float), S_v * sizeof(float),
903
+ S_v * sizeof(float), S_v);
904
+
905
+ ir_prefetch += nth;
906
+ spad_idx ^= 1;
907
+ }
908
+
909
+ int curr_spad_idx = 0;
910
+ for (uint32_t ir = ith; ir < total_rows; ir += nth) {
911
+ dma_queue_pop(dma);
912
+ dma_queue_pop(dma);
913
+
914
+ float * s_work_curr = s_work[curr_spad_idx];
915
+
916
+ const uint32_t iv1 = fastmodulo(ir, H, &fd_H);
917
+ const uint32_t iv3 = fastdiv(ir, &fd_H);
918
+
919
+ const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1);
920
+ const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1);
921
+ const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
922
+ const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
923
+
924
+ // final state lands in snapshot slot 0 (most-recent-first ordering)
925
+ float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
926
+
927
+ float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;
928
+
929
+ const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
930
+ (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]);
931
+ const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
932
+ (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]);
933
+ const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
934
+ (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]);
935
+ const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
936
+ (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]);
937
+ const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
938
+ (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]);
939
+
940
+ hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v);
941
+ hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v);
942
+
943
+ if (kda) {
944
+ hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
945
+
946
+ uint32_t j = 0;
947
+ for (; j + 8 <= S_v; j += 8) {
948
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
949
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
950
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
951
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
952
+ float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
953
+ float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
954
+ float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
955
+ float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
956
+ gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
957
+ local_gate, local_k, S_v, local_sums);
958
+
959
+ float local_delta_b[32] __attribute__((aligned(128)));
960
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
961
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
962
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
963
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
964
+
965
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
966
+ local_k, local_delta_b, local_q, S_v, local_sums);
967
+
968
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
969
+ hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
970
+ }
971
+ for (; j + 4 <= S_v; j += 4) {
972
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
973
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
974
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
975
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
976
+ gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
977
+
978
+ float local_delta_b[32] __attribute__((aligned(128)));
979
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
980
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
981
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
982
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
983
+
984
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
985
+
986
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
987
+ hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
988
+ }
989
+ HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
990
+ for (; j < S_v; ++j) {
991
+ float * row = s_work_curr + (uint64_t) j * S_v;
992
+ HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
993
+ HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
994
+ HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
995
+ HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
996
+ attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
997
+ }
998
+ } else {
999
+ const float gate = expf(g_t[0]);
1000
+ uint32_t j = 0;
1001
+ for (; j + 8 <= S_v; j += 8) {
1002
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
1003
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
1004
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
1005
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
1006
+ float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v;
1007
+ float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v;
1008
+ float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v;
1009
+ float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v;
1010
+ gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
1011
+ gate, local_k, S_v, local_sums);
1012
+
1013
+ float local_delta_b[32] __attribute__((aligned(128)));
1014
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
1015
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
1016
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
1017
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
1018
+
1019
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
1020
+ local_k, local_delta_b, local_q, S_v, local_sums);
1021
+
1022
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
1023
+ hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn);
1024
+ }
1025
+ for (; j + 4 <= S_v; j += 4) {
1026
+ float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v;
1027
+ float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v;
1028
+ float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v;
1029
+ float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v;
1030
+ gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
1031
+
1032
+ float local_delta_b[32] __attribute__((aligned(128)));
1033
+ HVX_Vector vv_t = hvx_vmemu(v_t + j);
1034
+ HVX_Vector v_local_sums = hvx_vmem(local_sums);
1035
+ HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums);
1036
+ hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val));
1037
+
1038
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
1039
+
1040
+ HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale));
1041
+ hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn);
1042
+ }
1043
+ HVX_Vector vscale_splat = hvx_vec_splat_f32(scale);
1044
+ for (; j < S_v; ++j) {
1045
+ float * row = s_work_curr + (uint64_t) j * S_v;
1046
+ HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
1047
+ HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]);
1048
+ HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val));
1049
+ HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v);
1050
+ attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat));
1051
+ }
1052
+ }
1053
+
1054
+ // Push real write-back
1055
+ dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr),
1056
+ S_v * sizeof(float), S_v * sizeof(float),
1057
+ S_v * sizeof(float), S_v);
1058
+
1059
+ // Prefetch next block (if any)
1060
+ if (ir_prefetch < total_rows) {
1061
+ const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
1062
+ const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
1063
+ const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
1064
+
1065
+ dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in),
1066
+ S_v * sizeof(float), S_v * sizeof(float),
1067
+ S_v * sizeof(float), S_v);
1068
+
1069
+ ir_prefetch += nth;
1070
+ spad_idx ^= 1;
1071
+ }
1072
+
1073
+ curr_spad_idx ^= 1;
1074
+ }
1075
+ dma_queue_flush(dma);
1076
+ }
1077
+
1078
+
1079
+ int op_gated_delta_net(struct htp_ops_context * octx) {
1080
+ const struct htp_tensor * q = octx->src[0];
1081
+ const struct htp_tensor * k = octx->src[1];
1082
+ const struct htp_tensor * v = octx->src[2];
1083
+ const struct htp_tensor * g = octx->src[3];
1084
+ const struct htp_tensor * beta = octx->src[4];
1085
+ const struct htp_tensor * state = octx->src[5];
1086
+ const struct htp_tensor * dst = octx->dst;
1087
+
1088
+ if (!q || !k || !v || !g || !beta || !state || !dst) {
1089
+ return HTP_STATUS_INVAL_PARAMS;
1090
+ }
1091
+
1092
+ if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 ||
1093
+ g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 ||
1094
+ dst->type != HTP_TYPE_F32) {
1095
+ return HTP_STATUS_NO_SUPPORT;
1096
+ }
1097
+
1098
+ const uint32_t S_v = v->ne[0];
1099
+ const uint32_t H = v->ne[1];
1100
+ const uint32_t n_tokens = v->ne[2];
1101
+ const uint32_t n_seqs = v->ne[3];
1102
+ const uint32_t K = octx->op_params[0];
1103
+
1104
+ if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
1105
+ return HTP_STATUS_NO_SUPPORT;
1106
+ }
1107
+ if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
1108
+ return HTP_STATUS_NO_SUPPORT;
1109
+ }
1110
+ if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 ||
1111
+ q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 ||
1112
+ (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
1113
+ return HTP_STATUS_NO_SUPPORT;
1114
+ }
1115
+ // state holds s0 only: [S_v, S_v, H, n_seqs]
1116
+ if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) {
1117
+ return HTP_STATUS_NO_SUPPORT;
1118
+ }
1119
+ if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
1120
+ return HTP_STATUS_NO_SUPPORT;
1121
+ }
1122
+
1123
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
1124
+ return HTP_STATUS_OK;
1125
+ }
1126
+
1127
+ struct htp_gdn_context gctx;
1128
+ gctx.octx = octx;
1129
+ gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads;
1130
+ gctx.state_bytes = (size_t) S_v * S_v * sizeof(float);
1131
+
1132
+ size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
1133
+ state_aligned = (state_aligned + 127) & ~(size_t)127;
1134
+
1135
+ assert(octx->ctx->vtcm_base != NULL);
1136
+ assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads);
1137
+
1138
+ gctx.vtcm_base = octx->ctx->vtcm_base;
1139
+ gctx.vtcm_per_thread = 2 * state_aligned;
1140
+
1141
+ if (n_tokens == 1) {
1142
+ worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads);
1143
+ } else {
1144
+ worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads);
1145
+ }
1146
+
1147
+ return HTP_STATUS_OK;
1148
+ }