whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -2,28 +2,82 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
8
-
9
5
  #include <HAP_farf.h>
10
- #include <HAP_mem.h>
11
6
  #include <HAP_perf.h>
12
- #include <HAP_ps.h>
13
- #include <hexagon_protos.h>
14
- #include <hexagon_types.h>
7
+
15
8
  #include <math.h>
16
- #include <qurt_thread.h>
17
9
  #include <string.h>
18
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-exp.h"
13
+ #include "hvx-sigmoid.h"
14
+ #include "hvx-utils.h"
15
+
19
16
  #define GGML_COMMON_DECL_C
20
17
  #include "ggml-common.h"
21
18
  #include "htp-ctx.h"
22
- #include "htp-dma.h"
23
- #include "htp-msg.h"
24
19
  #include "htp-ops.h"
25
- #include "hvx-utils.h"
26
- #include "ops-utils.h"
20
+
21
+ struct htp_unary_context {
22
+ struct htp_ops_context * octx;
23
+
24
+ // Precomputed values
25
+ const uint8_t * data_src0;
26
+ const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL
27
+ uint8_t * data_dst;
28
+
29
+ size_t src0_data_row_size; // actual data bytes per row
30
+ size_t src1_data_row_size;
31
+ size_t dst_data_row_size; // actual data bytes per row
32
+
33
+ size_t src0_row_size_aligned;
34
+ size_t src1_row_size_aligned;
35
+ size_t dst_row_size_aligned;
36
+
37
+ size_t src0_spad_half_size;
38
+ size_t src1_spad_half_size;
39
+ size_t dst_spad_half_size;
40
+
41
+ uint32_t block;
42
+ uint32_t src0_nrows;
43
+ uint32_t src0_nrows_per_thread;
44
+ uint32_t nc;
45
+ bool broadcast_weight;
46
+ };
47
+
48
+ // Convert flat row index to DDR byte offset using the tensor's actual strides.
49
+ // ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
50
+ static inline size_t unary_row_offset(uint32_t ir,
51
+ uint32_t ne1, uint32_t ne2,
52
+ size_t nb1, size_t nb2, size_t nb3) {
53
+ const uint32_t i1 = ir % ne1;
54
+ const uint32_t i2 = (ir / ne1) % ne2;
55
+ const uint32_t i3 = ir / (ne1 * ne2);
56
+ return i1 * nb1 + i2 * nb2 + i3 * nb3;
57
+ }
58
+ // Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
59
+ // boundary of src and dst so the nb1 stride stays valid for all rows.
60
+ static inline uint32_t unary_block_size(uint32_t ir,
61
+ uint32_t end_row,
62
+ uint32_t block,
63
+ bool src_contig,
64
+ bool dst_contig,
65
+ uint32_t src_ne1,
66
+ uint32_t dst_ne1) {
67
+ uint32_t limit = MIN(block, end_row - ir);
68
+
69
+ if (!src_contig) {
70
+ const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
71
+ limit = MIN(limit, src_slice_end - ir);
72
+ }
73
+
74
+ if (!dst_contig) {
75
+ const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
76
+ limit = MIN(limit, dst_slice_end - ir);
77
+ }
78
+
79
+ return limit;
80
+ }
27
81
 
28
82
  #define htp_unary_preamble \
29
83
  const uint32_t ne00 = src->ne[0]; \
@@ -51,110 +105,578 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
51
105
  uint8_t * restrict pad,
52
106
  const int num_elems,
53
107
  float epsilon) {
108
+ (void)pad;
109
+
54
110
  const HVX_Vector * restrict v_src = (HVX_Vector *) src;
55
111
  HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
56
112
 
57
- HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
58
- HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
113
+ const int nvec = num_elems / VLEN_FP32; // number of full vectors
114
+ const int nloe = num_elems % VLEN_FP32; // leftover elements
115
+
116
+ // Compute sum of squares for full vectors
117
+ HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
118
+ HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
119
+
120
+ #pragma unroll(4)
121
+ for (int i = 0; i < nvec; i++) {
122
+ HVX_Vector v1 = v_src[i];
123
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
124
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
125
+ }
126
+
127
+ // Handle tail elements using vectorized ops with masking
128
+ if (nloe > 0) {
129
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
130
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
131
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
132
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
133
+ }
134
+
135
+ // Reduce HVX sum
136
+ sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
137
+
138
+ HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
139
+ HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
140
+ HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
141
+ HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
142
+
143
+ // Scale full vectors
144
+ HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
59
145
 
60
- int step_of_1 = num_elems >> 5;
61
146
  #pragma unroll(4)
62
- for (int i = 0; i < step_of_1; i++) {
147
+ for (int i = 0; i < nvec; i++) {
148
+ HVX_Vector v1 = v_src[i];
149
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
150
+ v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
151
+ }
152
+
153
+ // Handle tail elements using vectorized ops with masking
154
+ if (nloe > 0) {
155
+
156
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
157
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
158
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
159
+ HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
160
+
161
+ // Store with masking to avoid overwriting memory beyond the tensor
162
+ hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
163
+ }
164
+ }
165
+
166
+ static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src,
167
+ const uint8_t * restrict weight,
168
+ uint8_t * restrict dst,
169
+ const int num_elems,
170
+ float epsilon) {
171
+ const HVX_Vector * restrict v_src = (const HVX_Vector *) src;
172
+ const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight;
173
+ HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
174
+
175
+ const int nvec = num_elems / VLEN_FP32; // number of full vectors
176
+ const int nloe = num_elems % VLEN_FP32; // leftover elements
177
+
178
+ // Compute sum of squares for full vectors
179
+ HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
180
+ HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
181
+
182
+ #pragma unroll(4)
183
+ for (int i = 0; i < nvec; i++) {
63
184
  HVX_Vector v1 = v_src[i];
64
185
  HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
65
- sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
186
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
66
187
  }
67
188
 
68
- HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
69
- sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
189
+ // Handle tail elements using vectorized ops with masking
190
+ if (nloe > 0) {
191
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
192
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
193
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
194
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
195
+ }
70
196
 
71
- HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems);
72
- HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v);
197
+ // Reduce HVX sum
198
+ sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
199
+
200
+ HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
201
+ HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
73
202
  HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
74
203
  HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
75
204
 
76
- HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
205
+ // Scale and multiply
206
+ HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
77
207
 
78
208
  #pragma unroll(4)
79
- for (int i = 0; i < step_of_1; i++) {
209
+ for (int i = 0; i < nvec; i++) {
80
210
  HVX_Vector v1 = v_src[i];
81
211
  HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
82
- v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
212
+ HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
213
+ HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]);
214
+ v_dst[i] = Q6_Vsf_equals_Vqf32(result);
215
+ }
216
+
217
+ // Handle tail elements using vectorized ops with masking
218
+ if (nloe > 0) {
219
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
220
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
221
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
222
+ HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
223
+ HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]);
224
+ HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result);
225
+
226
+ // Store with masking to avoid overwriting memory beyond the tensor
227
+ hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v);
83
228
  }
84
229
  }
85
230
 
86
- static void scale_htp_f32(const float * restrict src,
87
- float * restrict dst,
88
- uint8_t * restrict spad,
89
- const uint32_t num_rows,
90
- const uint32_t row_elems,
91
- const size_t row_size,
92
- int32_t * op_params,
93
- int opt_path) {
231
+ static void hvx_fast_norm_f32(const uint8_t * restrict src,
232
+ uint8_t * restrict dst,
233
+ uint8_t * restrict pad,
234
+ const int num_elems,
235
+ float epsilon) {
236
+ (void)pad;
237
+
238
+ const HVX_Vector * restrict v_src = (HVX_Vector *) src;
239
+ HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
240
+
241
+ const int nvec = num_elems / VLEN_FP32; // number of full vectors
242
+ const int nloe = num_elems % VLEN_FP32; // leftover elements
243
+
244
+ // Compute sum of squares and sum of values for full vectors
245
+ HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
246
+ HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
247
+ HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
248
+
249
+ #pragma unroll(4)
250
+ for (int i = 0; i < nvec; i++) {
251
+ HVX_Vector v1 = v_src[i];
252
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
253
+ sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
254
+ sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
255
+ }
256
+
257
+ // Handle tail elements using vectorized ops with masking
258
+ if (nloe > 0) {
259
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
260
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
261
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
262
+ sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
263
+ sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
264
+ }
265
+
266
+ // Reduce HVX sums
267
+ sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
268
+ sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
269
+
270
+ HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
271
+ HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
272
+ HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
273
+ HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
274
+ HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
275
+ HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
276
+ HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
277
+
278
+ // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
279
+ HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
280
+ HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v));
281
+
282
+ #pragma unroll(4)
283
+ for (int i = 0; i < nvec; i++) {
284
+ HVX_Vector v1 = v_src[i];
285
+ HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
286
+ HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
287
+ v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
288
+ }
289
+
290
+ // Handle tail elements using vectorized ops with masking
291
+ if (nloe > 0) {
292
+
293
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
294
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
295
+ HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
296
+ HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
297
+ HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
298
+
299
+ // Store with masking to avoid overwriting memory beyond the tensor
300
+ hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
301
+ }
302
+ }
303
+
304
+ static void scale_f32(const float * restrict src,
305
+ float * restrict dst,
306
+ uint8_t * restrict spad,
307
+ const uint32_t num_rows,
308
+ const uint32_t row_elems,
309
+ const size_t row_size,
310
+ int32_t * op_params) {
94
311
  float scale = 0.f;
95
312
  float bias = 0.f;
96
313
  memcpy(&scale, &op_params[0], sizeof(float));
97
314
  memcpy(&bias, &op_params[1], sizeof(float));
98
315
 
99
316
  for (uint32_t ir = 0; ir < num_rows; ir++) {
100
- const float * restrict src_local = src + (ir * row_elems);
101
- float * restrict dst_local = dst + (ir * row_elems);
317
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
318
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
102
319
 
103
- if (ir + 1 < num_rows) {
104
- htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
105
- }
320
+ hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
321
+ }
322
+ }
323
+
324
+ static void rms_norm_f32(const float * restrict src,
325
+ float * restrict dst,
326
+ uint8_t * restrict spad,
327
+ const uint32_t num_rows,
328
+ const uint32_t row_elems,
329
+ const size_t row_size,
330
+ int32_t * op_params) {
331
+ float epsilon = 0.f;
332
+ memcpy(&epsilon, op_params, sizeof(float));
106
333
 
107
- hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
334
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
335
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
336
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
337
+
338
+ hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
108
339
  }
109
340
  }
110
341
 
111
- static void rms_norm_htp_f32(const float * restrict src,
342
+ static void rms_norm_mul_f32(const float * restrict src,
343
+ const float * restrict weight,
112
344
  float * restrict dst,
113
- uint8_t * restrict spad,
114
345
  const uint32_t num_rows,
115
346
  const uint32_t row_elems,
116
347
  const size_t row_size,
348
+ const size_t weight_row_size,
117
349
  int32_t * op_params,
118
- int opt_path) {
350
+ bool broadcast_weight) {
119
351
  float epsilon = 0.f;
120
352
  memcpy(&epsilon, op_params, sizeof(float));
121
353
 
122
354
  for (uint32_t ir = 0; ir < num_rows; ir++) {
123
- const float * restrict src_local = src + (ir * row_elems);
124
- float * restrict dst_local = dst + (ir * row_elems);
355
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
356
+ const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size);
357
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
358
+
359
+ hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon);
360
+ }
361
+ }
362
+
363
+ static void norm_f32(const float * restrict src,
364
+ float * restrict dst,
365
+ uint8_t * restrict spad,
366
+ const uint32_t num_rows,
367
+ const uint32_t row_elems,
368
+ const size_t row_size,
369
+ int32_t * op_params) {
370
+ float epsilon = 0.f;
371
+ memcpy(&epsilon, op_params, sizeof(float));
372
+
373
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
374
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
375
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
376
+
377
+ hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
378
+ }
379
+ }
380
+
381
+ static void sqr_f32(const float * restrict src,
382
+ float * restrict dst,
383
+ uint8_t * restrict spad,
384
+ const uint32_t num_rows,
385
+ const uint32_t row_elems,
386
+ const size_t row_size,
387
+ int32_t * op_params) {
388
+
389
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
390
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
391
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
392
+
393
+ hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
394
+ }
395
+ }
396
+
397
+ static void sqrt_f32(const float * restrict src,
398
+ float * restrict dst,
399
+ uint8_t * restrict spad,
400
+ const uint32_t num_rows,
401
+ const uint32_t row_elems,
402
+ const size_t row_size,
403
+ int32_t * op_params) {
404
+
405
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
406
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
407
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
408
+
409
+ hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
410
+ }
411
+ }
412
+
413
+ static void neg_f32(const float * restrict src,
414
+ float * restrict dst,
415
+ uint8_t * restrict spad,
416
+ const uint32_t num_rows,
417
+ const uint32_t row_elems,
418
+ const size_t row_size,
419
+ int32_t * op_params) {
420
+
421
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
422
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
423
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
424
+
425
+ hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f);
426
+ }
427
+ }
428
+
429
+ static void exp_f32(const float * restrict src,
430
+ float * restrict dst,
431
+ uint8_t * restrict spad,
432
+ const uint32_t num_rows,
433
+ const uint32_t row_elems,
434
+ const size_t row_size,
435
+ int32_t * op_params) {
436
+
437
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
438
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
439
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
440
+
441
+ hvx_exp_f32(dst_local, src_local, row_elems, false);
442
+ }
443
+ }
444
+
445
+ static void sigmoid_f32(const float * restrict src,
446
+ float * restrict dst,
447
+ uint8_t * restrict spad,
448
+ const uint32_t num_rows,
449
+ const uint32_t row_elems,
450
+ const size_t row_size,
451
+ int32_t * op_params) {
452
+
453
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
454
+ const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
455
+ uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
125
456
 
126
- if (ir + 1 < num_rows) {
127
- htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
457
+ hvx_sigmoid_f32_aa(dst_local, src_local, row_elems);
458
+ }
459
+ }
460
+
461
+ static void tri_f32(const float * restrict src,
462
+ float * restrict dst,
463
+ uint8_t * restrict spad,
464
+ const uint32_t num_rows,
465
+ const uint32_t row_elems,
466
+ const size_t row_size,
467
+ int32_t * op_params,
468
+ const uint32_t ir,
469
+ const struct htp_unary_context * uctx) {
470
+
471
+ const int32_t ttype = op_params[0];
472
+ const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
473
+ const uint32_t nvec = row_elems / VLEN_FP32;
474
+ const uint32_t nloe = row_elems % VLEN_FP32;
475
+
476
+ const uint32_t ne01 = uctx->octx->src[0]->ne[1];
477
+
478
+ for (uint32_t b = 0; b < num_rows; b++) {
479
+ const uint32_t abs_row = ir + b;
480
+ const uint32_t i01 = abs_row % ne01;
481
+
482
+ const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
483
+ HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
484
+
485
+ uint32_t boundary;
486
+ int keep_left;
487
+ switch (ttype) {
488
+ case 0: boundary = i01; keep_left = 0; break; // keep col >= row
489
+ case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
490
+ case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
491
+ case 3: boundary = i01; keep_left = 1; break; // keep col < row
492
+ default: boundary = 0; keep_left = 0; break;
128
493
  }
494
+ if (boundary > row_elems) boundary = row_elems;
129
495
 
130
- if (1 == opt_path) {
131
- hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
132
- } else {
133
- float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
496
+ // Full HVX vectors — each starts at a 128-byte aligned offset
497
+ for (uint32_t i = 0; i < nvec; i++) {
498
+ const uint32_t vec_start = i * VLEN_FP32;
499
+ const uint32_t vec_end = vec_start + VLEN_FP32;
500
+ if (keep_left) {
501
+ if (vec_end <= boundary) {
502
+ v_dst[i] = v_src[i];
503
+ } else if (vec_start >= boundary) {
504
+ v_dst[i] = zero;
505
+ } else {
506
+ HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
507
+ v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
508
+ }
509
+ } else {
510
+ if (vec_end <= boundary) {
511
+ v_dst[i] = zero;
512
+ } else if (vec_start >= boundary) {
513
+ v_dst[i] = v_src[i];
514
+ } else {
515
+ HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
516
+ v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
517
+ }
518
+ }
519
+ }
520
+
521
+ // Tail elements (row_elems not a multiple of VLEN_FP32)
522
+ if (nloe > 0) {
523
+ const uint32_t vec_start = nvec * VLEN_FP32;
524
+ const uint32_t vec_end = vec_start + nloe;
525
+ HVX_Vector tail_val;
526
+ if (keep_left) {
527
+ if (vec_end <= boundary) {
528
+ tail_val = v_src[nvec];
529
+ } else if (vec_start >= boundary) {
530
+ tail_val = zero;
531
+ } else {
532
+ HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
533
+ tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
534
+ }
535
+ } else {
536
+ if (vec_end <= boundary) {
537
+ tail_val = zero;
538
+ } else if (vec_start >= boundary) {
539
+ tail_val = v_src[nvec];
540
+ } else {
541
+ HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
542
+ tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
543
+ }
544
+ }
545
+ hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
546
+ }
547
+ }
548
+ }
134
549
 
135
- const float mean = sum / row_elems;
136
- const float scale = 1.0f / sqrtf(mean + epsilon);
550
+ static void softplus_f32(const float * restrict src,
551
+ float * restrict dst,
552
+ uint8_t * restrict spad,
553
+ const uint32_t num_rows,
554
+ const uint32_t row_elems,
555
+ const size_t row_size,
556
+ int32_t * op_params) {
557
+ // softplus(x) = log(1 + exp(x))
558
+ // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
559
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
560
+ const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
561
+ float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
137
562
 
138
- hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
563
+ for (uint32_t i = 0; i < row_elems; i++) {
564
+ float x = src_f[i];
565
+ // For x > 20: softplus(x) ≈ x (avoids exp overflow)
566
+ dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
139
567
  }
140
568
  }
141
569
  }
142
570
 
143
- static void unary_job_f32_per_thread(const struct htp_tensor * src,
144
- struct htp_tensor * dst,
145
- uint8_t * spad,
146
- int htp_op,
147
- int32_t * op_params,
148
- uint32_t nth,
149
- uint32_t ith,
150
- uint32_t src0_nrows_per_thread) {
571
+ // --- L2_NORM HVX kernel ---
572
+ // Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
573
+ // scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
574
+ // using rsqrt + inverse to avoid scalar extraction.
575
+ static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
576
+ uint8_t * restrict dst,
577
+ uint8_t * restrict pad,
578
+ const int num_elems,
579
+ float epsilon) {
580
+ (void)pad;
581
+
582
+ const HVX_Vector * restrict v_src = (HVX_Vector *) src;
583
+ HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
584
+
585
+ HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
586
+
587
+ const int nvec = num_elems / VLEN_FP32;
588
+ const int nloe = num_elems % VLEN_FP32;
589
+
590
+ #pragma unroll(4)
591
+ for (int i = 0; i < nvec; i++) {
592
+ HVX_Vector v1 = v_src[i];
593
+ HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
594
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
595
+ }
596
+
597
+ // Include tail elements in the sum-of-squares using a predicate mask
598
+ if (nloe > 0) {
599
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
600
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
601
+ HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
602
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
603
+ }
604
+
605
+ // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
606
+ // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
607
+ HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
608
+ HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
609
+ HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
610
+ HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
611
+ HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
612
+ HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
613
+
614
+ #pragma unroll(4)
615
+ for (int i = 0; i < nvec; i++) {
616
+ HVX_Vector v1 = v_src[i];
617
+ v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
618
+ }
619
+
620
+ if (nloe > 0) {
621
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
622
+ HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
623
+ HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
624
+ hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
625
+ }
626
+ }
627
+
628
// Row driver for the L2-norm kernel: normalizes num_rows rows of row_elems
// floats each, with consecutive rows row_size bytes apart in src and dst.
// The first op param carries epsilon as the raw bit pattern of a float.
static void l2_norm_f32(const float * restrict src,
                        float * restrict       dst,
                        uint8_t * restrict     spad,
                        const uint32_t         num_rows,
                        const uint32_t         row_elems,
                        const size_t           row_size,
                        int32_t *              op_params) {
    // Reinterpret the int32 op param as float without violating strict aliasing.
    float eps = 0.f;
    memcpy(&eps, op_params, sizeof(float));

    const uint8_t * in_row  = (const uint8_t *) src;
    uint8_t *       out_row = (uint8_t *) dst;

    for (uint32_t row = 0; row < num_rows; ++row) {
        hvx_fast_l2_norm_f32(in_row, out_row, spad, row_elems, eps);
        in_row  += row_size;
        out_row += row_size;
    }
}
645
+
646
// Applies hvx_tanh_f32_aa() to num_rows rows of row_elems floats each.
// Consecutive rows are row_size bytes apart in both src and dst.
// spad and op_params are not used by this kernel.
// NOTE(review): the `_aa` variant presumably expects aligned src and dst rows; confirm.
static void tanh_f32(const float * restrict src,
                     float * restrict       dst,
                     uint8_t * restrict     spad,
                     const uint32_t         num_rows,
                     const uint32_t         row_elems,
                     const size_t           row_size,
                     int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * in_bytes  = (const uint8_t *) src;
    uint8_t *       out_bytes = (uint8_t *) dst;

    for (uint32_t remaining = num_rows; remaining > 0; --remaining) {
        hvx_tanh_f32_aa(out_bytes, in_bytes, row_elems);
        in_bytes  += row_size;
        out_bytes += row_size;
    }
}
660
+
661
+ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
662
+ const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
663
+ struct htp_ops_context * octx = uctx->octx;
664
+ const struct htp_tensor * src = octx->src[0];
665
+ const struct htp_tensor * dst = octx->dst;
666
+
151
667
  htp_unary_preamble;
152
668
 
153
- const size_t src0_row_size = nb01;
154
- const size_t dst_row_size = nb1;
669
+ int htp_op = octx->op;
670
+ int32_t * op_params = octx->op_params;
671
+ uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
155
672
 
156
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
673
+ const size_t src0_data_row_size = uctx->src0_data_row_size;
674
+ const size_t dst_data_row_size = uctx->dst_data_row_size;
157
675
 
676
+ const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
677
+ const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
678
+
679
+ const uint32_t src0_nrows = uctx->src0_nrows;
158
680
  const uint32_t src0_start_row = src0_nrows_per_thread * ith;
159
681
  const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
160
682
 
@@ -166,66 +688,212 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
166
688
  uint64_t t1, t2;
167
689
  t1 = HAP_perf_get_qtimer_count();
168
690
 
169
- int is_aligned = 1;
170
- int opt_path = 0;
171
- if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
172
- is_aligned = 0;
173
- FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
691
+ const uint8_t * restrict data_src = uctx->data_src0;
692
+ const uint8_t * restrict data_src1 = uctx->data_src1;
693
+ uint8_t * restrict data_dst = uctx->data_dst;
694
+
695
+ const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL;
696
+ const uint32_t nb11 = src1 ? src1->nb[1] : 0;
697
+ const uint32_t nb12 = src1 ? src1->nb[2] : 0;
698
+ const uint32_t nb13 = src1 ? src1->nb[3] : 0;
699
+
700
+ uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
701
+ uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
702
+ uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
703
+
704
+ size_t src0_spad_half_size = uctx->src0_spad_half_size;
705
+ size_t src1_spad_half_size = uctx->src1_spad_half_size;
706
+ size_t dst_spad_half_size = uctx->dst_spad_half_size;
707
+
708
+ // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
709
+ // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
710
+ // transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
711
+ const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
712
+ (nb03 == (size_t)ne02 * nb02);
713
+ const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
714
+ (nb3 == (size_t)ne2 * nb2);
715
+ const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
716
+ const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
717
+ const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
718
+ if (BLOCK == 0) {
719
+ FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
720
+ octx->src0_spad.size_per_thread, src0_row_size_aligned);
721
+ return;
174
722
  }
175
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
176
- opt_path = 1;
723
+
724
+ dma_queue * dma_queue = octx->ctx->dma[ith];
725
+
726
+ // If weight is broadcasted, load it once per thread at the beginning of execution
727
+ if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) {
728
+ dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1);
729
+ dma_queue_flush(dma_queue);
177
730
  }
178
731
 
179
- const uint8_t * restrict data_src = (const uint8_t *) src->data;
180
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
732
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
733
+ const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
181
734
 
182
- const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
183
- float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
184
- uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
735
+ // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
736
+ dma_queue_push(dma_queue,
737
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
738
+ nb1, dst_row_size_aligned, dst_data_row_size, 0);
185
739
 
186
- switch (htp_op) {
187
- case HTP_OP_RMS_NORM:
188
- rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
189
- break;
190
- case HTP_OP_SCALE:
191
- scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
192
- break;
740
+ const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
741
+ dma_queue_push(dma_queue,
742
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
743
+ src0_row_size_aligned, nb01, src0_data_row_size, block_size);
193
744
 
194
- default:
195
- break;
745
+ if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
746
+ const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13);
747
+ dma_queue_push(dma_queue,
748
+ dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off),
749
+ uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size);
750
+ }
751
+
752
+ ir += block_size;
196
753
  }
197
754
 
755
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
756
+ const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
757
+
758
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
759
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
760
+ float * src1_spad = NULL;
761
+ if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
762
+ src1_spad = (float *) dma_queue_pop(dma_queue).dst;
763
+ }
764
+
765
+ // Process block in VTCM
766
+ switch (htp_op) {
767
+ case HTP_OP_NORM:
768
+ norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
769
+ break;
770
+ case HTP_OP_RMS_NORM:
771
+ rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
772
+ break;
773
+ case HTP_OP_RMS_NORM_MUL:
774
+ {
775
+ const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad;
776
+ rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight);
777
+ }
778
+ break;
779
+ case HTP_OP_SCALE:
780
+ scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
781
+ break;
782
+ case HTP_OP_SQR:
783
+ sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
784
+ break;
785
+ case HTP_OP_SQRT:
786
+ sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
787
+ break;
788
+ case HTP_OP_UNARY_NEG:
789
+ neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
790
+ break;
791
+ case HTP_OP_UNARY_EXP:
792
+ exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
793
+ break;
794
+ case HTP_OP_UNARY_SIGMOID:
795
+ sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
796
+ break;
797
+ case HTP_OP_UNARY_SOFTPLUS:
798
+ softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
799
+ break;
800
+ case HTP_OP_UNARY_TANH:
801
+ tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
802
+ break;
803
+ case HTP_OP_L2_NORM:
804
+ l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
805
+ break;
806
+ case HTP_OP_TRI:
807
+ tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
808
+ break;
809
+ default:
810
+ break;
811
+ }
812
+
813
+ const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
814
+ dma_queue_push(dma_queue,
815
+ dma_make_ptr(data_dst + dst_off, dst_spad),
816
+ nb1, dst_row_size_aligned, dst_data_row_size, block_size);
817
+
818
+ // prefetch N+2 loop iteration if any
819
+ const uint32_t next_ir = ir + block_size;
820
+ if (next_ir < src0_end_row) {
821
+ const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
822
+ const uint32_t pref_ir = next_ir + next_block_size;
823
+ if (pref_ir < src0_end_row) {
824
+ const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
825
+ const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
826
+ dma_queue_push(dma_queue,
827
+ dma_make_ptr(src0_spad, data_src + src0_pref_off),
828
+ src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
829
+
830
+ if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
831
+ const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13);
832
+ dma_queue_push(dma_queue,
833
+ dma_make_ptr(src1_spad, data_src1 + src1_pref_off),
834
+ uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size);
835
+ }
836
+ }
837
+ }
838
+ ir += block_size;
839
+ }
840
+
841
+ dma_queue_flush(dma_queue);
842
+
198
843
  t2 = HAP_perf_get_qtimer_count();
199
844
 
200
- FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
845
+ FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
201
846
  src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
202
847
  dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
203
848
  }
204
849
 
205
- static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
206
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
207
-
208
- unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
209
- octx->src0_nrows_per_thread);
210
- }
211
-
212
850
  static int execute_op_unary_f32(struct htp_ops_context * octx) {
213
851
  int err = HTP_STATUS_OK;
214
852
 
215
- const struct htp_tensor * src0 = &octx->src0;
216
- struct htp_tensor * dst = &octx->dst;
853
+ const struct htp_tensor * src0 = octx->src[0];
854
+ const struct htp_tensor * dst = octx->dst;
217
855
 
218
- worker_callback_t unary_op_func;
219
- const char * op_type = NULL;
856
+ const char * op_type = NULL;
220
857
 
221
858
  switch (octx->op) {
859
+ case HTP_OP_NORM:
860
+ op_type = "norm-f32";
861
+ break;
222
862
  case HTP_OP_RMS_NORM:
223
- unary_op_func = unary_job_dispatcher_f32;
224
- op_type = "rmsnorm-f32";
863
+ op_type = "rmsnorm-f32";
864
+ break;
865
+ case HTP_OP_RMS_NORM_MUL:
866
+ op_type = "rmsnorm-mul-f32";
225
867
  break;
226
868
  case HTP_OP_SCALE:
227
- unary_op_func = unary_job_dispatcher_f32;
228
- op_type = "scale-f32";
869
+ op_type = "scale-f32";
870
+ break;
871
+ case HTP_OP_SQR:
872
+ op_type = "sqr-f32";
873
+ break;
874
+ case HTP_OP_SQRT:
875
+ op_type = "sqrt-f32";
876
+ break;
877
+ case HTP_OP_UNARY_NEG:
878
+ op_type = "neg-f32";
879
+ break;
880
+ case HTP_OP_UNARY_EXP:
881
+ op_type = "exp-f32";
882
+ break;
883
+ case HTP_OP_UNARY_SIGMOID:
884
+ op_type = "sigmoid-f32";
885
+ break;
886
+ case HTP_OP_UNARY_SOFTPLUS:
887
+ op_type = "softplus-f32";
888
+ break;
889
+ case HTP_OP_UNARY_TANH:
890
+ op_type = "tanh-f32";
891
+ break;
892
+ case HTP_OP_L2_NORM:
893
+ op_type = "l2norm-f32";
894
+ break;
895
+ case HTP_OP_TRI:
896
+ op_type = "tri-f32";
229
897
  break;
230
898
 
231
899
  default:
@@ -233,38 +901,139 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
233
901
  return HTP_STATUS_NO_SUPPORT;
234
902
  }
235
903
 
236
- const int n_threads = octx->n_threads;
237
904
  const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
905
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
238
906
 
239
- const size_t src0_row_size = src0->nb[1];
240
- const size_t dst_row_size = dst->nb[1];
907
+ const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
908
+ const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
909
+
910
+ const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
911
+ const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
912
+
913
+ size_t src1_data_row_size = 0;
914
+ size_t src1_row_size_aligned = 0;
915
+ bool broadcast_weight = false;
916
+ const struct htp_tensor * src1 = NULL;
917
+
918
+ if (octx->op == HTP_OP_RMS_NORM_MUL) {
919
+ src1 = octx->src[1];
920
+ src1_data_row_size = src1->ne[0] * sizeof(float);
921
+ src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN);
922
+ broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1);
923
+ }
241
924
 
242
925
  // VTCM scratchpads for all tensors
243
- octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
244
- octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
926
+ // N rows per thread, padded to HVX vector size
927
+ // Double buffering requires 2x size per buffer
245
928
 
246
- size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
929
+ size_t spad_size_per_row = 0;
930
+ size_t vtcm_row_per_thread = 0;
247
931
 
248
- FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
249
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
250
- octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
932
+ if (octx->op == HTP_OP_RMS_NORM_MUL) {
933
+ if (broadcast_weight) {
934
+ size_t available_vtcm = octx->ctx->vtcm_size;
935
+ size_t src1_spad_total = n_threads * src1_row_size_aligned;
936
+ if (available_vtcm > src1_spad_total) {
937
+ available_vtcm -= src1_spad_total;
938
+ } else {
939
+ available_vtcm = 0;
940
+ }
941
+ spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
942
+ vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row);
943
+ } else {
944
+ spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned);
945
+ vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row);
946
+ }
947
+ } else {
948
+ spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
949
+ vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
950
+ }
251
951
 
252
952
  // Make sure the reserved vtcm size is sufficient
253
- if (octx->ctx->vtcm_size < spad_size) {
953
+ if (vtcm_row_per_thread == 0) {
254
954
  FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
255
- spad_size);
955
+ spad_size_per_row * n_threads);
256
956
  return HTP_STATUS_VTCM_TOO_SMALL;
257
957
  }
258
958
 
959
+ octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
960
+ octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
961
+
962
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
963
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
964
+
965
+ if (octx->op == HTP_OP_RMS_NORM_MUL) {
966
+ if (broadcast_weight) {
967
+ octx->src1_spad.size_per_thread = src1_row_size_aligned;
968
+ } else {
969
+ octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2;
970
+ }
971
+ octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
972
+ } else {
973
+ octx->src1_spad.size = 0;
974
+ octx->src1_spad.size_per_thread = 0;
975
+ }
976
+
259
977
  octx->src0_spad.data = octx->ctx->vtcm_base;
260
- octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
978
+ if (octx->op == HTP_OP_RMS_NORM_MUL) {
979
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
980
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
981
+ } else {
982
+ octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
983
+ }
984
+
985
+ octx->src0_spad.src = NULL;
986
+ octx->src1_spad.src = NULL;
987
+ octx->dst_spad.src = NULL;
988
+
989
+ FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
990
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
991
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
261
992
 
262
993
  if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
263
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
994
+ struct htp_unary_context uctx = {
995
+ .octx = octx,
996
+ .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
997
+ .src0_nrows = src0_nrows,
998
+
999
+ .data_src0 = (const uint8_t *)src0->data,
1000
+ .data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL,
1001
+ .data_dst = (uint8_t *)dst->data,
1002
+
1003
+ .src0_data_row_size = src0_data_row_size,
1004
+ .src1_data_row_size = src1_data_row_size,
1005
+ .dst_data_row_size = dst_data_row_size,
264
1006
 
265
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
1007
+ .src0_row_size_aligned = src0_row_size_aligned,
1008
+ .src1_row_size_aligned = src1_row_size_aligned,
1009
+ .dst_row_size_aligned = dst_row_size_aligned,
266
1010
 
267
- worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
1011
+ .src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
1012
+ .src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0,
1013
+ .dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
1014
+
1015
+ .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
1016
+ .nc = src0->ne[0],
1017
+ .broadcast_weight = broadcast_weight,
1018
+ };
1019
+
1020
+ worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
1021
+ }
1022
+
1023
+ return err;
1024
+ }
1025
+
1026
+ int op_tri(struct htp_ops_context * octx) {
1027
+ int err = HTP_STATUS_OK;
1028
+
1029
+ switch (octx->src[0]->type) {
1030
+ case HTP_TYPE_F32:
1031
+ err = execute_op_unary_f32(octx);
1032
+ break;
1033
+
1034
+ default:
1035
+ err = HTP_STATUS_NO_SUPPORT;
1036
+ break;
268
1037
  }
269
1038
 
270
1039
  return err;
@@ -273,7 +1042,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
273
1042
  int op_unary(struct htp_ops_context * octx) {
274
1043
  int err = HTP_STATUS_OK;
275
1044
 
276
- switch (octx->src0.type) {
1045
+ switch (octx->src[0]->type) {
277
1046
  case HTP_TYPE_F32:
278
1047
  err = execute_op_unary_f32(octx);
279
1048
  break;