whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017) hide show
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -2,42 +2,51 @@
2
2
  #pragma clang diagnostic ignored "-Wunused-function"
3
3
  #pragma clang diagnostic ignored "-Wunused-but-set-variable"
4
4
 
5
- #ifdef HTP_DEBUG
6
- # define FARF_HIGH 1
7
- #endif
8
-
9
5
  #include <HAP_farf.h>
10
- #include <HAP_mem.h>
11
6
  #include <HAP_perf.h>
12
- #include <HAP_ps.h>
13
- #include <hexagon_protos.h>
14
- #include <hexagon_types.h>
7
+
15
8
  #include <math.h>
16
- #include <qurt_thread.h>
17
9
  #include <string.h>
18
10
 
11
+ #include "hex-dma.h"
12
+ #include "hvx-utils.h"
13
+
19
14
  #define GGML_COMMON_DECL_C
20
15
  #include "ggml-common.h"
21
16
  #include "htp-ctx.h"
22
- #include "htp-dma.h"
23
- #include "htp-msg.h"
24
17
  #include "htp-ops.h"
25
- #include "hvx-utils.h"
26
- #include "ops-utils.h"
18
+ #include "htp-ops.h"
19
+
20
+ #ifndef MIN
21
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
22
+ #endif
23
+
24
+ // Context for binary operations
25
+ struct htp_binary_context {
26
+ struct htp_ops_context * octx;
27
27
 
28
- typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
29
- const uint8_t * src1,
30
- uint8_t * data_dst,
31
- const int num_elems);
28
+ struct fastdiv_values src0_dim1_div; // ne01
29
+ struct fastdiv_values src0_dim2_div; // ne02
30
+ struct fastdiv_values src0_dim12_div;// ne03
32
31
 
33
- static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
34
- static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
32
+ struct fastdiv_values src1_dim1_div; // ne11
33
+ struct fastdiv_values src1_dim2_div; // ne12
34
+ struct fastdiv_values src1_dim3_div; // ne13
35
35
 
36
- #define htp_binary_preamble \
37
- const struct htp_tensor * src0 = &octx->src0; \
38
- const struct htp_tensor * src1 = &octx->src1; \
39
- const struct htp_tensor * src2 = &octx->src2; \
40
- struct htp_tensor * dst = &octx->dst; \
36
+ uint32_t block_max;
37
+ uint32_t nrows_per_thread;
38
+ size_t src0_row_size_aligned;
39
+ size_t src1_row_size_aligned;
40
+ size_t dst_row_size_aligned;
41
+
42
+ bool split_at_ne01;
43
+ bool split_at_ne02;
44
+ };
45
+
46
+ #define htp_binary_preamble \
47
+ const struct htp_tensor * src0 = octx->src[0]; \
48
+ const struct htp_tensor * src1 = octx->src[1]; \
49
+ const struct htp_tensor * dst = octx->dst; \
41
50
  \
42
51
  const uint32_t ne00 = src0->ne[0]; \
43
52
  const uint32_t ne01 = src0->ne[1]; \
@@ -49,312 +58,815 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
49
58
  const uint32_t ne12 = src1->ne[2]; \
50
59
  const uint32_t ne13 = src1->ne[3]; \
51
60
  \
52
- const uint32_t ne0 = dst->ne[0]; \
53
- const uint32_t ne1 = dst->ne[1]; \
54
- const uint32_t ne2 = dst->ne[2]; \
55
- const uint32_t ne3 = dst->ne[3]; \
56
- \
57
- const uint32_t nb00 = src0->nb[0]; \
58
61
  const uint32_t nb01 = src0->nb[1]; \
59
62
  const uint32_t nb02 = src0->nb[2]; \
60
63
  const uint32_t nb03 = src0->nb[3]; \
61
64
  \
62
- const uint32_t nb10 = src1->nb[0]; \
63
65
  const uint32_t nb11 = src1->nb[1]; \
64
66
  const uint32_t nb12 = src1->nb[2]; \
65
67
  const uint32_t nb13 = src1->nb[3]; \
66
68
  \
67
- const uint32_t nb0 = dst->nb[0]; \
68
69
  const uint32_t nb1 = dst->nb[1]; \
69
70
  const uint32_t nb2 = dst->nb[2]; \
70
- const uint32_t nb3 = dst->nb[3]; \
71
- \
72
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
71
+ const uint32_t nb3 = dst->nb[3];
73
72
 
74
- static void binary_job_f32_per_thread(struct htp_ops_context * octx,
75
- uint8_t * spad_data,
76
- uint32_t nth,
77
- uint32_t ith,
78
- enum htp_op op) {
79
- htp_binary_preamble;
73
+ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) {
74
+ uint32_t i03, i02, i01, rem;
75
+ i03 = fastdiv(ir, &bctx->src0_dim12_div);
76
+ rem = ir - i03 * (ne02 * ne01);
77
+ i02 = fastdiv(rem, &bctx->src0_dim1_div);
78
+ i01 = rem - i02 * ne01;
80
79
 
81
- const size_t src0_row_size = nb01;
82
- const size_t src1_row_size = nb11;
83
- const size_t dst_row_size = nb1;
80
+ uint32_t rows_left = end_row - ir;
81
+ uint32_t block_limit = rows_left;
84
82
 
85
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
86
- const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
83
+ if (bctx->split_at_ne01) {
84
+ block_limit = MIN(block_limit, ne01 - i01);
85
+ }
86
+ if (bctx->split_at_ne02) {
87
+ uint32_t rows_in_plane = (ne02 * ne01) - rem;
88
+ block_limit = MIN(block_limit, rows_in_plane);
89
+ }
87
90
 
88
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
89
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
91
+ return MIN(bctx->block_max, block_limit);
92
+ }
90
93
 
91
- // no work for this thread
92
- if (src0_start_row >= src0_end_row) {
93
- return;
94
+ // Macro for scalar op switch
95
+ #define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
96
+ if(TYPE == HTP_TYPE_F32) { \
97
+ switch (octx->op) { \
98
+ case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
99
+ case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
100
+ case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
101
+ case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
102
+ default: break; \
103
+ } \
104
+ } \
105
+ else { \
106
+ switch (octx->op) { \
107
+ case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
108
+ case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
109
+ case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
110
+ case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
111
+ default: break; \
112
+ } \
94
113
  }
95
114
 
96
- uint64_t t1, t2;
97
- t1 = HAP_perf_get_qtimer_count();
98
-
99
- int is_aligned = 1;
100
- int opt_path = 0;
101
- if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
102
- (0 == htp_is_aligned((void *) dst->data, VLEN))) {
103
- FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
104
- is_aligned = 0;
115
+ // Macro for vector op switch (All Aligned)
116
+ #define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
117
+ if(TYPE == HTP_TYPE_F32) { \
118
+ switch (octx->op) { \
119
+ case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
120
+ case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
121
+ case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
122
+ case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
123
+ default: break; \
124
+ } \
125
+ } \
126
+ else { \
127
+ switch (octx->op) { \
128
+ case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
129
+ case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
130
+ case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
131
+ case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
132
+ default: break; \
133
+ } \
105
134
  }
106
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
107
- opt_path = 1;
135
+
136
+ // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
137
+ #define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
138
+ if(TYPE == HTP_TYPE_F32) { \
139
+ switch (octx->op) { \
140
+ case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
141
+ case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
142
+ case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
143
+ case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
144
+ default: break; \
145
+ } \
146
+ } \
147
+ else { \
148
+ switch (octx->op) { \
149
+ case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
150
+ case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
151
+ case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
152
+ case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
153
+ default: break; \
154
+ } \
108
155
  }
109
156
 
110
- hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
157
+ // Macro for vector op switch (All Unaligned - generic loop used in element repeat)
158
+ #define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
159
+ if(TYPE == HTP_TYPE_F32) { \
160
+ switch (octx->op) { \
161
+ case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
162
+ case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
163
+ case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
164
+ case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
165
+ default: break; \
166
+ } \
167
+ } \
168
+ else { \
169
+ switch (octx->op) { \
170
+ case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
171
+ case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
172
+ case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
173
+ case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
174
+ default: break; \
175
+ } \
176
+ }
111
177
 
112
- uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
178
+ // 1. Scalar src1 (ne10 == 1)
179
+ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
180
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
181
+ struct htp_ops_context * octx = bctx->octx;
182
+ htp_binary_preamble;
113
183
 
114
- const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
115
- uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
184
+ const uint32_t src0_type = octx->src[0]->type;
185
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
186
+ const uint32_t total_rows = ne01 * ne02 * ne03;
187
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
188
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
189
+ if (start_row >= end_row) return;
190
+
191
+ FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
192
+
193
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
194
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
195
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
196
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
197
+
198
+ dma_queue * q = octx->ctx->dma[ith];
199
+ uint32_t ir_prefetch = start_row;
200
+ int spad_idx = 0;
201
+
202
+ // Preamble
203
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
204
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
205
+ uint32_t i03, i02, i01, rem;
206
+ i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
207
+ rem = ir_prefetch - i03 * (ne02 * ne01);
208
+ i02 = fastdiv(rem, &bctx->src0_dim1_div);
209
+ i01 = rem - i02 * ne01;
210
+
211
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
212
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
213
+
214
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
215
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
216
+
217
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
218
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
219
+ ir_prefetch += current_block_size;
220
+ spad_idx ^= 1;
221
+ }
116
222
 
117
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
223
+ // Main loop
224
+ for (uint32_t ir = start_row; ir < end_row; ) {
225
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
118
226
 
119
- const uint32_t ne02_ne01 = ne02 * ne01;
227
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
228
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
120
229
 
121
- for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
122
- const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
123
- const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
124
- const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
230
+ uint32_t i03, i02, i01, rem;
231
+ i03 = fastdiv(ir, &bctx->src0_dim12_div);
232
+ rem = ir - i03 * (ne02 * ne01);
233
+ i02 = fastdiv(rem, &bctx->src0_dim1_div);
234
+ i01 = rem - i02 * ne01;
125
235
 
126
- const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
127
- const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
128
- const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
236
+ // src1 indices (broadcast/repeat)
237
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
238
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
239
+ uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
129
240
 
130
- const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
241
+ uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
242
+ uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
131
243
 
132
- if (ir + 1 < src0_end_row) {
133
- htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
134
- if (src1_row_size == src0_row_size) {
135
- htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
136
- }
244
+ for (uint32_t r = 0; r < current_block_size; r++) {
245
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
246
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
247
+ COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
248
+ src1_ptr += s1_stride;
137
249
  }
138
250
 
139
- const uint32_t nr0 = ne00 / ne10;
140
- if (nr0 > 1) {
141
- if ((1 == is_aligned) && (nr0 == ne00)) {
142
- hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
143
- } else {
144
- for (uint32_t r = 0; r < nr0; r++) {
145
- memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
146
- }
147
- }
148
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
149
- } else {
150
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
251
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
252
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
253
+
254
+ if (ir_prefetch < end_row) {
255
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
256
+ uint32_t p03, p02, p01, prem;
257
+ p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
258
+ prem = ir_prefetch - p03 * (ne02 * ne01);
259
+ p02 = fastdiv(prem, &bctx->src0_dim1_div);
260
+ p01 = prem - p02 * ne01;
261
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
262
+
263
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
264
+ ir_prefetch += next_block_size;
151
265
  }
266
+ ir += current_block_size;
267
+ }
268
+ dma_queue_flush(q);
269
+ }
152
270
 
153
- src0_ptr += src0_row_size;
154
- dst_ptr += dst_row_size;
271
+ // 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
272
+ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
273
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
274
+ struct htp_ops_context * octx = bctx->octx;
275
+ htp_binary_preamble;
276
+
277
+ const uint32_t src0_type = octx->src[0]->type;
278
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
279
+ const uint32_t total_rows = ne01 * ne02 * ne03;
280
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
281
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
282
+ if (start_row >= end_row) return;
283
+
284
+ FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
285
+
286
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
287
+ uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
288
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
289
+
290
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
291
+ size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
292
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
293
+
294
+ dma_queue * q = octx->ctx->dma[ith];
295
+ uint32_t ir_prefetch = start_row;
296
+ int spad_idx = 0;
297
+
298
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
299
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
300
+ uint32_t i03, i02, i01, rem;
301
+ i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
302
+ rem = ir_prefetch - i03 * (ne02 * ne01);
303
+ i02 = fastdiv(rem, &bctx->src0_dim1_div);
304
+ i01 = rem - i02 * ne01;
305
+
306
+ uint32_t i13 = (ne13 == 1) ? 0 : i03;
307
+ uint32_t i12 = (ne12 == 1) ? 0 : i02;
308
+ uint32_t i11 = (ne11 == 1) ? 0 : i01;
309
+
310
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
311
+ uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
312
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
313
+
314
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
315
+ uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
316
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
317
+
318
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
319
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
320
+ dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size);
321
+ ir_prefetch += current_block_size;
322
+ spad_idx ^= 1;
155
323
  }
156
324
 
157
- t2 = HAP_perf_get_qtimer_count();
325
+ for (uint32_t ir = start_row; ir < end_row; ) {
326
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
327
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
328
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
329
+ uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
330
+
331
+ for (uint32_t r = 0; r < current_block_size; r++) {
332
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
333
+ uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
334
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
335
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
336
+ }
158
337
 
159
- FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
160
- ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
161
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
338
+ uint32_t i03, i02, i01, rem;
339
+ i03 = fastdiv(ir, &bctx->src0_dim12_div);
340
+ rem = ir - i03 * (ne02 * ne01);
341
+ i02 = fastdiv(rem, &bctx->src0_dim1_div);
342
+ i01 = rem - i02 * ne01;
343
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
344
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
345
+
346
+ if (ir_prefetch < end_row) {
347
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
348
+ uint32_t p03, p02, p01, prem;
349
+ p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
350
+ prem = ir_prefetch - p03 * (ne02 * ne01);
351
+ p02 = fastdiv(prem, &bctx->src0_dim1_div);
352
+ p01 = prem - p02 * ne01;
353
+
354
+ uint32_t p13 = (ne13 == 1) ? 0 : p03;
355
+ uint32_t p12 = (ne12 == 1) ? 0 : p02;
356
+ uint32_t p11 = (ne11 == 1) ? 0 : p01;
357
+
358
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
359
+ uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
360
+
361
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
362
+ dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size);
363
+
364
+ ir_prefetch += next_block_size;
365
+ }
366
+ ir += current_block_size;
367
+ }
368
+ dma_queue_flush(q);
162
369
  }
163
370
 
164
- static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
165
- uint8_t * spad_data,
166
- uint32_t nth,
167
- uint32_t ith,
168
- hvx_elemwise_f32_func func_HVX) {
371
+ // 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)
372
+ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
373
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
374
+ struct htp_ops_context * octx = bctx->octx;
169
375
  htp_binary_preamble;
170
376
 
171
- const size_t src0_row_size = nb01;
172
- const size_t src1_row_size = nb11;
173
- const size_t dst_row_size = nb1;
377
+ const uint32_t src0_type = octx->src[0]->type;
378
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
379
+ const uint32_t total_rows = ne01 * ne02 * ne03;
380
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
381
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
382
+ if (start_row >= end_row) return;
174
383
 
175
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
384
+ FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
176
385
 
177
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
178
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
386
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
387
+ uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
388
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
179
389
 
180
- // no work for this thread
181
- if (src0_start_row >= src0_end_row) {
182
- return;
183
- }
390
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
391
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
392
+
393
+ dma_queue * q = octx->ctx->dma[ith];
394
+ uint32_t ir_prefetch = start_row;
395
+ int spad_idx = 0;
396
+
397
+ void * s1_ptr = (void *) src1_spad_base;
398
+
399
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
400
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
401
+ uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
402
+ uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
403
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
404
+ uint32_t i01 = rem - i02 * ne01;
405
+
406
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
407
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
184
408
 
185
- uint64_t t1, t2;
186
- t1 = HAP_perf_get_qtimer_count();
409
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
410
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
187
411
 
188
- if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
189
- (0 == htp_is_aligned((void *) dst->data, VLEN))) {
190
- FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
412
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
413
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
414
+ ir_prefetch += current_block_size;
415
+ spad_idx ^= 1;
191
416
  }
192
417
 
193
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
194
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
195
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
196
-
197
- const uint32_t ne02_ne01 = ne02 * ne01;
198
- for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
199
- // src0 indices
200
- const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
201
- const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
202
- const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
203
-
204
- // src1 indices
205
- const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
206
- assert(i11 >= 0 && i11 < ne11);
207
-
208
- float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
209
- const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
210
- const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
211
-
212
- if (ir + 1 < src0_end_row) {
213
- htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
214
- if (src1_row_size == src0_row_size) {
215
- htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
216
- }
418
+ for (uint32_t ir = start_row; ir < end_row; ) {
419
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
420
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
421
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
422
+
423
+ for (uint32_t r = 0; r < current_block_size; r++) {
424
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
425
+ uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
426
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
427
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
217
428
  }
218
429
 
219
- const uint32_t nr0 = ne00 / ne10;
220
- if (nr0 > 1) {
221
- for (uint32_t r = 0; r < nr0; r++) {
222
- memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
223
- }
224
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
225
- } else {
226
- func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
430
+ uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
431
+ uint32_t rem = ir - i03 * (ne02 * ne01);
432
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
433
+ uint32_t i01 = rem - i02 * ne01;
434
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
435
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
436
+
437
+ if (ir_prefetch < end_row) {
438
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
439
+ uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
440
+ uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
441
+ uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
442
+ uint32_t p01 = prem - p02 * ne01;
443
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
444
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
445
+ ir_prefetch += next_block_size;
227
446
  }
447
+ ir += current_block_size;
448
+ }
449
+ dma_queue_flush(q);
450
+ }
451
+
452
+ // 4. Vector Complex (ne10 == ne00, complex broadcast)
453
+ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
454
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
455
+ struct htp_ops_context * octx = bctx->octx;
456
+ htp_binary_preamble;
457
+
458
+ const uint32_t src0_type = octx->src[0]->type;
459
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
460
+ const uint32_t total_rows = ne01 * ne02 * ne03;
461
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
462
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
463
+ if (start_row >= end_row) return;
464
+
465
+ FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
466
+
467
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
468
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
469
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
470
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
471
+
472
+ dma_queue * q = octx->ctx->dma[ith];
473
+ uint32_t ir_prefetch = start_row;
474
+ int spad_idx = 0;
475
+
476
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
477
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
478
+ uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
479
+ uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
480
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
481
+ uint32_t i01 = rem - i02 * ne01;
482
+
483
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
484
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
485
+
486
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
487
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
488
+
489
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
490
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
491
+ ir_prefetch += current_block_size;
492
+ spad_idx ^= 1;
228
493
  }
229
494
 
230
- t2 = HAP_perf_get_qtimer_count();
495
+ for (uint32_t ir = start_row; ir < end_row; ) {
496
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
497
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
498
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
499
+
500
+ uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
501
+ uint32_t rem = ir - i03 * (ne02 * ne01);
502
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
503
+ uint32_t i01 = rem - i02 * ne01;
504
+
505
+ for (uint32_t r = 0; r < current_block_size; r++) {
506
+ uint32_t r_i01 = i01 + r;
507
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
508
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
509
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
231
510
 
232
- FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
233
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
234
- src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
235
- dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
511
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
512
+ uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
513
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
514
+
515
+ // Read src1 from DDR (unaligned)
516
+ COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
517
+ }
518
+
519
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
520
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
521
+
522
+ if (ir_prefetch < end_row) {
523
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
524
+ uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
525
+ uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
526
+ uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
527
+ uint32_t p01 = prem - p02 * ne01;
528
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
529
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
530
+ ir_prefetch += next_block_size;
531
+ }
532
+ ir += current_block_size;
533
+ }
534
+ dma_queue_flush(q);
236
535
  }
237
536
 
238
- static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
239
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
537
+ // 5. Element Repeat (ne10 != ne00)
538
+ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
539
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
540
+ struct htp_ops_context * octx = bctx->octx;
541
+ htp_binary_preamble;
240
542
 
241
- switch (octx->op) {
242
- case HTP_OP_MUL:
243
- case HTP_OP_ADD:
244
- case HTP_OP_SUB:
245
- binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
246
- break;
543
+ const uint32_t src0_type = octx->src[0]->type;
544
+ const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
545
+ const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
546
+ const uint32_t total_rows = ne01 * ne02 * ne03;
547
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
548
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
549
+ if (start_row >= end_row) return;
550
+
551
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
552
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
553
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
554
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
555
+
556
+ FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
557
+
558
+ dma_queue * q = octx->ctx->dma[ith];
559
+ uint32_t ir_prefetch = start_row;
560
+ int spad_idx = 0;
561
+
562
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
563
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
564
+ uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
565
+ uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
566
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
567
+ uint32_t i01 = rem - i02 * ne01;
568
+
569
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
570
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
571
+
572
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
573
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
574
+
575
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
576
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
577
+ ir_prefetch += current_block_size;
578
+ spad_idx ^= 1;
579
+ }
247
580
 
248
- case HTP_OP_ADD_ID:
249
- binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
250
- break;
581
+ for (uint32_t ir = start_row; ir < end_row; ) {
582
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
583
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
584
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
585
+
586
+ uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
587
+ uint32_t rem = ir - i03 * (ne02 * ne01);
588
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
589
+ uint32_t i01 = rem - i02 * ne01;
590
+
591
+ for (uint32_t r = 0; r < current_block_size; r++) {
592
+ uint32_t r_i01 = i01 + r;
593
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
594
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
595
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
596
+
597
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
598
+ uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
599
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
600
+
601
+ // Repeat src1 row
602
+ for (uint32_t c = 0; c < ne00; c += ne10) {
603
+ uint32_t len = MIN(ne10, ne00 - c);
604
+ // Use UUU for speed and simplicity
605
+ COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
606
+ }
607
+ }
251
608
 
252
- default:
253
- FARF(ERROR, "Unknown Binary Op %u", octx->op);
254
- break;
609
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
610
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
611
+
612
+ if (ir_prefetch < end_row) {
613
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
614
+ uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
615
+ uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
616
+ uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
617
+ uint32_t p01 = prem - p02 * ne01;
618
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
619
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
620
+ ir_prefetch += next_block_size;
621
+ }
622
+ ir += current_block_size;
255
623
  }
624
+ dma_queue_flush(q);
256
625
  }
257
626
 
258
- static int execute_op_binary_f32(struct htp_ops_context * octx) {
259
- int err = HTP_STATUS_OK;
627
+ // 6. ADD_ID (src1 gathered via src2 indices)
628
+ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
629
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
630
+ struct htp_ops_context * octx = bctx->octx;
631
+
632
+ const struct htp_tensor * src0 = octx->src[0];
633
+ const struct htp_tensor * src1 = octx->src[1];
634
+ const struct htp_tensor * src2 = octx->src[2];
635
+ const struct htp_tensor * dst = octx->dst;
636
+
637
+ const uint32_t ne00 = src0->ne[0];
638
+ const uint32_t ne01 = src0->ne[1];
639
+ const uint32_t ne02 = src0->ne[2];
640
+ const uint32_t ne03 = src0->ne[3];
641
+ const uint32_t ne11 = src1->ne[1]; // for bounds check
642
+
643
+ const uint32_t nb01 = src0->nb[1];
644
+ const uint32_t nb02 = src0->nb[2];
645
+ const uint32_t nb03 = src0->nb[3];
646
+ const uint32_t nb11 = src1->nb[1]; // src1 row stride
647
+
648
+ const uint32_t nb1 = dst->nb[1];
649
+ const uint32_t nb2 = dst->nb[2];
650
+ const uint32_t nb3 = dst->nb[3];
651
+
652
+ const uint32_t total_rows = ne01 * ne02 * ne03;
653
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
654
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
655
+ if (start_row >= end_row) return;
656
+
657
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
658
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
659
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
660
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
661
+
662
+ dma_queue * q = octx->ctx->dma[ith];
663
+ uint32_t ir_prefetch = start_row;
664
+ int spad_idx = 0;
665
+
666
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
667
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
668
+ uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
669
+ uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
670
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
671
+ uint32_t i01 = rem - i02 * ne01;
672
+
673
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
674
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
675
+
676
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
677
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
678
+
679
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0);
680
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
681
+ ir_prefetch += current_block_size;
682
+ spad_idx ^= 1;
683
+ }
260
684
 
261
- const struct htp_tensor * src0 = &octx->src0;
262
- const struct htp_tensor * src1 = &octx->src1;
263
- struct htp_tensor * dst = &octx->dst;
685
+ for (uint32_t ir = start_row; ir < end_row; ) {
686
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
687
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
688
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
264
689
 
265
- worker_callback_t binary_op_func;
266
- const char * op_type = NULL;
690
+ uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
691
+ uint32_t rem = ir - i03 * (ne02 * ne01);
692
+ uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
693
+ uint32_t i01 = rem - i02 * ne01;
267
694
 
268
- switch (octx->op) {
269
- case HTP_OP_MUL:
270
- binary_op_func = binary_job_dispatcher_f32;
271
- op_type = "mul-f32";
272
- break;
695
+ for (uint32_t r = 0; r < current_block_size; r++) {
696
+ uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
273
697
 
274
- case HTP_OP_ADD:
275
- binary_op_func = binary_job_dispatcher_f32;
276
- op_type = "add-f32";
277
- break;
698
+ const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
278
699
 
279
- case HTP_OP_SUB:
280
- binary_op_func = binary_job_dispatcher_f32;
281
- op_type = "sub-f32";
282
- break;
700
+ uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
701
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
702
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
283
703
 
284
- case HTP_OP_ADD_ID:
285
- binary_op_func = binary_job_dispatcher_f32;
286
- op_type = "add-id-f32";
287
- break;
704
+ hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
705
+ }
288
706
 
289
- default:
290
- FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
291
- return HTP_STATUS_NO_SUPPORT;
707
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
708
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
709
+
710
+ if (ir_prefetch < end_row) {
711
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
712
+ uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
713
+ uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
714
+ uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
715
+ uint32_t p01 = prem - p02 * ne01;
716
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
717
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
718
+ ir_prefetch += next_block_size;
719
+ }
720
+ ir += current_block_size;
292
721
  }
722
+ dma_queue_flush(q);
723
+ }
724
+
725
+ static int execute_op_binary(struct htp_ops_context * octx) {
726
+ const struct htp_tensor * src0 = octx->src[0];
727
+ const struct htp_tensor * src1 = octx->src[1];
728
+ const struct htp_tensor * dst = octx->dst;
293
729
 
294
- const int n_threads = octx->n_threads;
295
730
  const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
731
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
732
+
733
+ // Use packed row sizes for VTCM allocation
734
+ const uint32_t src0_type = octx->src[0]->type;
735
+ const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
736
+ const size_t src0_row_size = src0->ne[0] * elem_size;
737
+ const size_t src1_row_size = src1->ne[0] * elem_size;
738
+ const size_t dst_row_size = dst->ne[0] * elem_size;
739
+
740
+ size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
741
+ size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
742
+ size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
743
+
744
+ bool is_add_id = (octx->op == HTP_OP_ADD_ID);
745
+ bool is_scalar = !is_add_id && (src1->ne[0] == 1);
746
+
747
+ bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size);
748
+
749
+ bool is_same_shape = !is_add_id && !is_scalar && !is_transposed &&
750
+ (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) &&
751
+ (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
752
+ (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
753
+ (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
754
+
755
+ bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
756
+ bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]);
757
+ bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]);
758
+
759
+ size_t spad_row_total;
760
+ if (is_same_shape) {
761
+ spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
762
+ } else {
763
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
764
+ }
296
765
 
297
- const size_t src0_row_size = src0->nb[1];
298
- const size_t src1_row_size = src1->nb[1];
299
- const size_t dst_row_size = dst->nb[1];
766
+ size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
300
767
 
301
- // VTCM scratchpads for all tensors
302
- octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
303
- octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
304
- octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
768
+ // Adjust for static src1 in row_bcast case
769
+ if (is_row_bcast) {
770
+ size_t needed_static = src1_row_size_aligned;
771
+ if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
772
+ size_t avail = octx->ctx->vtcm_size - needed_static;
773
+ rows_per_buffer = avail / (n_threads * spad_row_total);
774
+ }
305
775
 
306
- size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
776
+ if (rows_per_buffer < 1) {
777
+ FARF(ERROR, "binary: VTCM too small\n");
778
+ return HTP_STATUS_VTCM_TOO_SMALL;
779
+ }
307
780
 
308
- FARF(HIGH,
309
- "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
310
- op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
311
- src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
312
- octx->dst_spad.size);
781
+ octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
782
+ octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
313
783
 
314
- // Make sure the reserved vtcm size is sufficient
315
- if (octx->ctx->vtcm_size < spad_size) {
316
- FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
317
- octx->ctx->vtcm_size, spad_size);
318
- return HTP_STATUS_VTCM_TOO_SMALL;
784
+ if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) {
785
+ octx->src1_spad.size_per_thread = 0;
786
+ } else {
787
+ octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
319
788
  }
320
789
 
321
- octx->src0_spad.data = octx->ctx->vtcm_base;
322
- octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
323
- octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
790
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
791
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
792
+ if (is_row_bcast) {
793
+ octx->src1_spad.size = src1_row_size_aligned;
794
+ } else {
795
+ octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
796
+ }
324
797
 
325
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
326
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
798
+ if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
799
+ return HTP_STATUS_VTCM_TOO_SMALL;
800
+ }
327
801
 
328
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
802
+ octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
803
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
804
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
329
805
 
330
- octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
331
- octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
332
- octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
333
- octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
806
+ if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
807
+ return HTP_STATUS_OK;
808
+ }
334
809
 
335
- octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
336
- octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
337
- octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
338
- octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
810
+ dma_queue * q = octx->ctx->dma[0];
811
+ if (is_row_bcast) {
812
+ dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
813
+ }
339
814
 
340
- worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
815
+ struct htp_binary_context bctx;
816
+ bctx.octx = octx;
817
+ bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
818
+ bctx.block_max = rows_per_buffer;
819
+ bctx.src0_row_size_aligned = src0_row_size_aligned;
820
+ bctx.src1_row_size_aligned = src1_row_size_aligned;
821
+ bctx.dst_row_size_aligned = dst_row_size_aligned;
822
+
823
+ bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]);
824
+ bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]);
825
+ bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
826
+
827
+ bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
828
+ bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
829
+ bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
830
+
831
+ bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
832
+ bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
833
+
834
+ bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
835
+ bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
836
+
837
+ bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
838
+ bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
839
+
840
+ worker_callback_t worker_func;
841
+ if (is_add_id) worker_func = binary_job_add_id;
842
+ else if (is_scalar) worker_func = binary_job_scalar;
843
+ else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
844
+ else if (is_same_shape) worker_func = binary_job_vector_same_shape;
845
+ else if (is_complex) worker_func = binary_job_vector_complex;
846
+ else worker_func = binary_job_element_repeat;
847
+
848
+ if (is_row_bcast) {
849
+ dma_queue_pop(q);
341
850
  }
342
851
 
343
- return err;
852
+ worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
853
+
854
+ return HTP_STATUS_OK;
344
855
  }
345
856
 
346
857
  int op_binary(struct htp_ops_context * octx) {
347
- int err = HTP_STATUS_OK;
348
858
 
349
- switch (octx->src0.type) {
350
- case HTP_TYPE_F32:
351
- err = execute_op_binary_f32(octx);
352
- break;
859
+ // Does not support permutations of src1
860
+ const struct htp_tensor * src1 = octx->src[1];
861
+ if (src1->nb[1] < src1->nb[0]) {
862
+ return HTP_STATUS_NO_SUPPORT;
863
+ }
353
864
 
354
- default:
355
- err = HTP_STATUS_NO_SUPPORT;
356
- break;
865
+ const uint32_t src0_type = octx->src[0]->type;
866
+ if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
867
+ return execute_op_binary(octx);
357
868
  }
358
869
 
359
- return err;
870
+ return HTP_STATUS_NO_SUPPORT;
360
871
  }
872
+