whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -3,14 +3,14 @@
3
3
  #include "ggml-cpu.h"
4
4
  #include "ggml-impl.h"
5
5
  #include "binary-ops.h"
6
+ #include "simd-gemm.h"
6
7
  #include "ggml.h"
7
8
  #include "unary-ops.h"
8
9
  #include "vec.h"
9
10
 
10
- #include <cfloat>
11
11
  #include <algorithm>
12
+ #include <cfloat>
12
13
  #include <cmath>
13
- #include <functional>
14
14
 
15
15
  // ggml_compute_forward_dup
16
16
 
@@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
375
375
  const size_t rs = ne00 * type_size;
376
376
 
377
377
  if (nb00 == type_size) {
378
- // src0 is contigous on first dimension, copy by rows
378
+ // src0 is contiguous on first dimension, copy by rows
379
379
  for (int64_t i03 = 0; i03 < ne03; i03++) {
380
380
  for (int64_t i02 = 0; i02 < ne02; i02++) {
381
381
  id += rs * ir0;
@@ -664,12 +664,14 @@ void ggml_compute_forward_add(
664
664
  {
665
665
  ggml_compute_forward_add_non_quantized(params, dst);
666
666
  } break;
667
+ case GGML_TYPE_Q1_0:
667
668
  case GGML_TYPE_Q4_0:
668
669
  case GGML_TYPE_Q4_1:
669
670
  case GGML_TYPE_Q5_0:
670
671
  case GGML_TYPE_Q5_1:
671
672
  case GGML_TYPE_Q8_0:
672
673
  case GGML_TYPE_MXFP4:
674
+ case GGML_TYPE_NVFP4:
673
675
  case GGML_TYPE_Q2_K:
674
676
  case GGML_TYPE_Q3_K:
675
677
  case GGML_TYPE_Q4_K:
@@ -1112,6 +1114,7 @@ void ggml_compute_forward_add1(
1112
1114
  GGML_ABORT("fatal error");
1113
1115
  }
1114
1116
  } break;
1117
+ case GGML_TYPE_Q1_0:
1115
1118
  case GGML_TYPE_Q4_0:
1116
1119
  case GGML_TYPE_Q4_1:
1117
1120
  case GGML_TYPE_Q5_0:
@@ -1119,6 +1122,7 @@ void ggml_compute_forward_add1(
1119
1122
  case GGML_TYPE_Q8_0:
1120
1123
  case GGML_TYPE_Q8_1:
1121
1124
  case GGML_TYPE_MXFP4:
1125
+ case GGML_TYPE_NVFP4:
1122
1126
  case GGML_TYPE_Q2_K:
1123
1127
  case GGML_TYPE_Q3_K:
1124
1128
  case GGML_TYPE_Q4_K:
@@ -1240,6 +1244,7 @@ void ggml_compute_forward_acc(
1240
1244
  } break;
1241
1245
  case GGML_TYPE_F16:
1242
1246
  case GGML_TYPE_BF16:
1247
+ case GGML_TYPE_Q1_0:
1243
1248
  case GGML_TYPE_Q4_0:
1244
1249
  case GGML_TYPE_Q4_1:
1245
1250
  case GGML_TYPE_Q5_0:
@@ -1247,6 +1252,7 @@ void ggml_compute_forward_acc(
1247
1252
  case GGML_TYPE_Q8_0:
1248
1253
  case GGML_TYPE_Q8_1:
1249
1254
  case GGML_TYPE_MXFP4:
1255
+ case GGML_TYPE_NVFP4:
1250
1256
  case GGML_TYPE_Q2_K:
1251
1257
  case GGML_TYPE_Q3_K:
1252
1258
  case GGML_TYPE_Q4_K:
@@ -1795,7 +1801,7 @@ void ggml_compute_forward_repeat(
1795
1801
  {
1796
1802
  ggml_compute_forward_repeat_f32(params, dst);
1797
1803
  } break;
1798
- // TODO: templateify the implemenation and support for I64
1804
+ // TODO: templateify the implementation and support for I64
1799
1805
  // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1800
1806
  //case GGML_TYPE_I64:
1801
1807
  // {
@@ -2097,10 +2103,14 @@ static void ggml_compute_forward_gelu_f32(
2097
2103
 
2098
2104
  const ggml_tensor * src0 = dst->src[0];
2099
2105
 
2100
- assert(ggml_is_contiguous_1(src0));
2101
- assert(ggml_is_contiguous_1(dst));
2106
+ assert(ggml_is_contiguous_rows(src0));
2102
2107
  assert(ggml_are_same_shape(src0, dst));
2103
2108
 
2109
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2110
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2111
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2112
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2113
+
2104
2114
  const int ith = params->ith;
2105
2115
  const int nth = params->nth;
2106
2116
 
@@ -2114,19 +2124,23 @@ static void ggml_compute_forward_gelu_f32(
2114
2124
  const int ir0 = dr*ith;
2115
2125
  const int ir1 = MIN(ir0 + dr, nr);
2116
2126
 
2117
- for (int i1 = ir0; i1 < ir1; i1++) {
2127
+ for (int ir = ir0; ir < ir1; ++ir) {
2128
+ const int i3 = ir/(ne02*ne01);
2129
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2130
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2131
+
2118
2132
  ggml_vec_gelu_f32(nc,
2119
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2120
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2133
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2134
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2121
2135
 
2122
2136
  #ifndef NDEBUG
2123
2137
  for (int k = 0; k < nc; k++) {
2124
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2138
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2125
2139
  GGML_UNUSED(x);
2126
2140
  assert(!isnan(x));
2127
2141
  assert(!isinf(x));
2128
2142
  }
2129
- #endif
2143
+ #endif // NDEBUG
2130
2144
  }
2131
2145
  }
2132
2146
 
@@ -2136,10 +2150,14 @@ static void ggml_compute_forward_gelu_f16(
2136
2150
 
2137
2151
  const ggml_tensor * src0 = dst->src[0];
2138
2152
 
2139
- assert(ggml_is_contiguous_1(src0));
2140
- assert(ggml_is_contiguous_1(dst));
2153
+ assert(ggml_is_contiguous_rows(src0));
2141
2154
  assert(ggml_are_same_shape(src0, dst));
2142
2155
 
2156
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2157
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2158
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2159
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2160
+
2143
2161
  const int ith = params->ith;
2144
2162
  const int nth = params->nth;
2145
2163
 
@@ -2153,20 +2171,24 @@ static void ggml_compute_forward_gelu_f16(
2153
2171
  const int ir0 = dr*ith;
2154
2172
  const int ir1 = MIN(ir0 + dr, nr);
2155
2173
 
2156
- for (int i1 = ir0; i1 < ir1; i1++) {
2174
+ for (int ir = ir0; ir < ir1; ++ir) {
2175
+ const int i3 = ir/(ne02*ne01);
2176
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2177
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2178
+
2157
2179
  ggml_vec_gelu_f16(nc,
2158
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2159
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2180
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2181
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2160
2182
 
2161
2183
  #ifndef NDEBUG
2162
2184
  for (int k = 0; k < nc; k++) {
2163
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2185
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2164
2186
  const float v = GGML_CPU_FP16_TO_FP32(x);
2165
2187
  GGML_UNUSED(v);
2166
2188
  assert(!isnan(v));
2167
2189
  assert(!isinf(v));
2168
2190
  }
2169
- #endif
2191
+ #endif // NDEBUG
2170
2192
  }
2171
2193
  }
2172
2194
 
@@ -2213,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
2213
2235
  }
2214
2236
  }
2215
2237
 
2238
+ static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
2239
+ const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
2240
+
2241
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2242
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
2243
+
2244
+ const auto [ir0, ir1] = get_thread_range(params, dst);
2245
+
2246
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
2247
+ const int64_t i03 = ir/(ne2*ne1);
2248
+ const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2249
+ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2250
+
2251
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2252
+
2253
+ ggml_vec_set_f16(ne0, dst_ptr, c);
2254
+ }
2255
+ }
2256
+
2216
2257
  void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2217
- ggml_compute_forward_fill_f32(params, dst);
2258
+ const ggml_tensor * src0 = dst->src[0];
2259
+
2260
+ switch (src0->type) {
2261
+ case GGML_TYPE_F32:
2262
+ {
2263
+ ggml_compute_forward_fill_f32(params, dst);
2264
+ } break;
2265
+ case GGML_TYPE_F16:
2266
+ {
2267
+ ggml_compute_forward_fill_f16(params, dst);
2268
+ } break;
2269
+ default:
2270
+ {
2271
+ GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
2272
+ }
2273
+ }
2218
2274
  }
2219
2275
 
2220
2276
  // ggml_compute_tri
@@ -2277,10 +2333,14 @@ static void ggml_compute_forward_gelu_erf_f32(
2277
2333
 
2278
2334
  const ggml_tensor * src0 = dst->src[0];
2279
2335
 
2280
- assert(ggml_is_contiguous_1(src0));
2281
- assert(ggml_is_contiguous_1(dst));
2336
+ assert(ggml_is_contiguous_rows(src0));
2282
2337
  assert(ggml_are_same_shape(src0, dst));
2283
2338
 
2339
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2340
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2341
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2342
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2343
+
2284
2344
  const int ith = params->ith;
2285
2345
  const int nth = params->nth;
2286
2346
 
@@ -2294,19 +2354,23 @@ static void ggml_compute_forward_gelu_erf_f32(
2294
2354
  const int ir0 = dr*ith;
2295
2355
  const int ir1 = MIN(ir0 + dr, nr);
2296
2356
 
2297
- for (int i1 = ir0; i1 < ir1; i1++) {
2357
+ for (int ir = ir0; ir < ir1; ++ir) {
2358
+ const int i3 = ir/(ne02*ne01);
2359
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2360
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2361
+
2298
2362
  ggml_vec_gelu_erf_f32(nc,
2299
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2300
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2363
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2364
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2301
2365
 
2302
2366
  #ifndef NDEBUG
2303
2367
  for (int k = 0; k < nc; k++) {
2304
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2368
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2305
2369
  GGML_UNUSED(x);
2306
2370
  assert(!isnan(x));
2307
2371
  assert(!isinf(x));
2308
2372
  }
2309
- #endif
2373
+ #endif // NDEBUG
2310
2374
  }
2311
2375
  }
2312
2376
 
@@ -2316,10 +2380,14 @@ static void ggml_compute_forward_gelu_erf_f16(
2316
2380
 
2317
2381
  const ggml_tensor * src0 = dst->src[0];
2318
2382
 
2319
- assert(ggml_is_contiguous_1(src0));
2320
- assert(ggml_is_contiguous_1(dst));
2383
+ assert(ggml_is_contiguous_rows(src0));
2321
2384
  assert(ggml_are_same_shape(src0, dst));
2322
2385
 
2386
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2387
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2388
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2389
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2390
+
2323
2391
  const int ith = params->ith;
2324
2392
  const int nth = params->nth;
2325
2393
 
@@ -2333,20 +2401,24 @@ static void ggml_compute_forward_gelu_erf_f16(
2333
2401
  const int ir0 = dr*ith;
2334
2402
  const int ir1 = MIN(ir0 + dr, nr);
2335
2403
 
2336
- for (int i1 = ir0; i1 < ir1; i1++) {
2404
+ for (int ir = ir0; ir < ir1; ++ir) {
2405
+ const int i3 = ir/(ne02*ne01);
2406
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2407
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2408
+
2337
2409
  ggml_vec_gelu_erf_f16(nc,
2338
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2339
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2410
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2411
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2340
2412
 
2341
2413
  #ifndef NDEBUG
2342
2414
  for (int k = 0; k < nc; k++) {
2343
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2415
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2344
2416
  const float v = GGML_CPU_FP16_TO_FP32(x);
2345
2417
  GGML_UNUSED(v);
2346
2418
  assert(!isnan(v));
2347
2419
  assert(!isinf(v));
2348
2420
  }
2349
- #endif
2421
+ #endif // NDEBUG
2350
2422
  }
2351
2423
  }
2352
2424
 
@@ -2380,10 +2452,14 @@ static void ggml_compute_forward_gelu_quick_f32(
2380
2452
 
2381
2453
  const ggml_tensor * src0 = dst->src[0];
2382
2454
 
2383
- assert(ggml_is_contiguous_1(src0));
2384
- assert(ggml_is_contiguous_1(dst));
2455
+ assert(ggml_is_contiguous_rows(src0));
2385
2456
  assert(ggml_are_same_shape(src0, dst));
2386
2457
 
2458
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2459
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2460
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2461
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2462
+
2387
2463
  const int ith = params->ith;
2388
2464
  const int nth = params->nth;
2389
2465
 
@@ -2397,19 +2473,23 @@ static void ggml_compute_forward_gelu_quick_f32(
2397
2473
  const int ir0 = dr*ith;
2398
2474
  const int ir1 = MIN(ir0 + dr, nr);
2399
2475
 
2400
- for (int i1 = ir0; i1 < ir1; i1++) {
2476
+ for (int ir = ir0; ir < ir1; ++ir) {
2477
+ const int i3 = ir/(ne02*ne01);
2478
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2479
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2480
+
2401
2481
  ggml_vec_gelu_quick_f32(nc,
2402
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2403
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2482
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2483
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2404
2484
 
2405
2485
  #ifndef NDEBUG
2406
2486
  for (int k = 0; k < nc; k++) {
2407
- const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2487
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2408
2488
  GGML_UNUSED(x);
2409
2489
  assert(!isnan(x));
2410
2490
  assert(!isinf(x));
2411
2491
  }
2412
- #endif
2492
+ #endif // NDEBUG
2413
2493
  }
2414
2494
  }
2415
2495
 
@@ -2419,10 +2499,14 @@ static void ggml_compute_forward_gelu_quick_f16(
2419
2499
 
2420
2500
  const ggml_tensor * src0 = dst->src[0];
2421
2501
 
2422
- assert(ggml_is_contiguous_1(src0));
2423
- assert(ggml_is_contiguous_1(dst));
2502
+ assert(ggml_is_contiguous_rows(src0));
2424
2503
  assert(ggml_are_same_shape(src0, dst));
2425
2504
 
2505
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2506
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2507
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2508
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2509
+
2426
2510
  const int ith = params->ith;
2427
2511
  const int nth = params->nth;
2428
2512
 
@@ -2436,20 +2520,24 @@ static void ggml_compute_forward_gelu_quick_f16(
2436
2520
  const int ir0 = dr*ith;
2437
2521
  const int ir1 = MIN(ir0 + dr, nr);
2438
2522
 
2439
- for (int i1 = ir0; i1 < ir1; i1++) {
2523
+ for (int ir = ir0; ir < ir1; ++ir) {
2524
+ const int i3 = ir/(ne02*ne01);
2525
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2526
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2527
+
2440
2528
  ggml_vec_gelu_quick_f16(nc,
2441
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2442
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2529
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2530
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2443
2531
 
2444
2532
  #ifndef NDEBUG
2445
2533
  for (int k = 0; k < nc; k++) {
2446
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2534
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2447
2535
  const float v = GGML_CPU_FP16_TO_FP32(x);
2448
2536
  GGML_UNUSED(v);
2449
2537
  assert(!isnan(v));
2450
2538
  assert(!isinf(v));
2451
2539
  }
2452
- #endif
2540
+ #endif // NDEBUG
2453
2541
  }
2454
2542
  }
2455
2543
 
@@ -2483,10 +2571,14 @@ static void ggml_compute_forward_silu_f32(
2483
2571
 
2484
2572
  const ggml_tensor * src0 = dst->src[0];
2485
2573
 
2486
- assert(ggml_is_contiguous_1(src0));
2487
- assert(ggml_is_contiguous_1(dst));
2574
+ assert(ggml_is_contiguous_rows(src0));
2488
2575
  assert(ggml_are_same_shape(src0, dst));
2489
2576
 
2577
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2578
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2579
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2580
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2581
+
2490
2582
  const int ith = params->ith;
2491
2583
  const int nth = params->nth;
2492
2584
 
@@ -2500,19 +2592,23 @@ static void ggml_compute_forward_silu_f32(
2500
2592
  const int ir0 = dr*ith;
2501
2593
  const int ir1 = MIN(ir0 + dr, nr);
2502
2594
 
2503
- for (int i1 = ir0; i1 < ir1; i1++) {
2595
+ for (int ir = ir0; ir < ir1; ++ir) {
2596
+ const int i3 = ir/(ne02*ne01);
2597
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2598
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2599
+
2504
2600
  ggml_vec_silu_f32(nc,
2505
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
2506
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
2601
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2602
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2507
2603
 
2508
2604
  #ifndef NDEBUG
2509
2605
  for (int k = 0; k < nc; k++) {
2510
- const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2606
+ const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2511
2607
  GGML_UNUSED(x);
2512
2608
  assert(!isnan(x));
2513
2609
  assert(!isinf(x));
2514
2610
  }
2515
- #endif
2611
+ #endif // NDEBUG
2516
2612
  }
2517
2613
  }
2518
2614
 
@@ -2522,10 +2618,14 @@ static void ggml_compute_forward_silu_f16(
2522
2618
 
2523
2619
  const ggml_tensor * src0 = dst->src[0];
2524
2620
 
2525
- assert(ggml_is_contiguous_1(src0));
2526
- assert(ggml_is_contiguous_1(dst));
2621
+ assert(ggml_is_contiguous_rows(src0));
2527
2622
  assert(ggml_are_same_shape(src0, dst));
2528
2623
 
2624
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2625
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2626
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2627
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2628
+
2529
2629
  const int ith = params->ith;
2530
2630
  const int nth = params->nth;
2531
2631
 
@@ -2539,20 +2639,24 @@ static void ggml_compute_forward_silu_f16(
2539
2639
  const int ir0 = dr*ith;
2540
2640
  const int ir1 = MIN(ir0 + dr, nr);
2541
2641
 
2542
- for (int i1 = ir0; i1 < ir1; i1++) {
2642
+ for (int ir = ir0; ir < ir1; ++ir) {
2643
+ const int i3 = ir/(ne02*ne01);
2644
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
2645
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2646
+
2543
2647
  ggml_vec_silu_f16(nc,
2544
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2545
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2648
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2649
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2546
2650
 
2547
2651
  #ifndef NDEBUG
2548
2652
  for (int k = 0; k < nc; k++) {
2549
- const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2653
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2550
2654
  const float v = GGML_CPU_FP16_TO_FP32(x);
2551
2655
  GGML_UNUSED(v);
2552
2656
  assert(!isnan(v));
2553
2657
  assert(!isinf(v));
2554
2658
  }
2555
- #endif
2659
+ #endif // NDEBUG
2556
2660
  }
2557
2661
  }
2558
2662
 
@@ -2702,7 +2806,7 @@ static void ggml_compute_forward_silu_back_f32(
2702
2806
  assert(!isnan(x));
2703
2807
  assert(!isinf(x));
2704
2808
  }
2705
- #endif
2809
+ #endif // NDEBUG
2706
2810
  }
2707
2811
  }
2708
2812
 
@@ -2738,7 +2842,7 @@ static void ggml_compute_forward_silu_back_f16(
2738
2842
  (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2739
2843
  (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2740
2844
 
2741
- #ifndef NDEBUG
2845
+ #ifndef NDEBUG
2742
2846
  for (int k = 0; k < nc; k++) {
2743
2847
  const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2744
2848
  const float v = GGML_CPU_FP16_TO_FP32(x);
@@ -2746,7 +2850,7 @@ static void ggml_compute_forward_silu_back_f16(
2746
2850
  assert(!isnan(v));
2747
2851
  assert(!isinf(v));
2748
2852
  }
2749
- #endif
2853
+ #endif // NDEBUG
2750
2854
  }
2751
2855
  }
2752
2856
 
@@ -2829,7 +2933,7 @@ static void ggml_compute_forward_reglu_f32(
2829
2933
  assert(!isnan(x));
2830
2934
  assert(!isinf(x));
2831
2935
  }
2832
- #endif
2936
+ #endif // NDEBUG
2833
2937
  }
2834
2938
  }
2835
2939
 
@@ -2889,7 +2993,7 @@ static void ggml_compute_forward_reglu_f16(
2889
2993
  assert(!isnan(v));
2890
2994
  assert(!isinf(v));
2891
2995
  }
2892
- #endif
2996
+ #endif // NDEBUG
2893
2997
  }
2894
2998
  }
2895
2999
 
@@ -2972,7 +3076,7 @@ static void ggml_compute_forward_geglu_f32(
2972
3076
  assert(!isnan(x));
2973
3077
  assert(!isinf(x));
2974
3078
  }
2975
- #endif
3079
+ #endif // NDEBUG
2976
3080
  }
2977
3081
  }
2978
3082
 
@@ -3032,7 +3136,7 @@ static void ggml_compute_forward_geglu_f16(
3032
3136
  assert(!isnan(v));
3033
3137
  assert(!isinf(v));
3034
3138
  }
3035
- #endif
3139
+ #endif // NDEBUG
3036
3140
  }
3037
3141
  }
3038
3142
 
@@ -3115,7 +3219,7 @@ static void ggml_compute_forward_swiglu_f32(
3115
3219
  assert(!isnan(x));
3116
3220
  assert(!isinf(x));
3117
3221
  }
3118
- #endif
3222
+ #endif // NDEBUG
3119
3223
  }
3120
3224
  }
3121
3225
 
@@ -3175,7 +3279,7 @@ static void ggml_compute_forward_swiglu_f16(
3175
3279
  assert(!isnan(v));
3176
3280
  assert(!isinf(v));
3177
3281
  }
3178
- #endif
3282
+ #endif // NDEBUG
3179
3283
  }
3180
3284
  }
3181
3285
 
@@ -3266,7 +3370,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
3266
3370
  assert(!isnan(x));
3267
3371
  assert(!isinf(x));
3268
3372
  }
3269
- #endif
3373
+ #endif // NDEBUG
3270
3374
  }
3271
3375
  }
3272
3376
 
@@ -3345,7 +3449,7 @@ static void ggml_compute_forward_geglu_erf_f32(
3345
3449
  assert(!isnan(x));
3346
3450
  assert(!isinf(x));
3347
3451
  }
3348
- #endif
3452
+ #endif // NDEBUG
3349
3453
  }
3350
3454
  }
3351
3455
 
@@ -3405,7 +3509,7 @@ static void ggml_compute_forward_geglu_erf_f16(
3405
3509
  assert(!isnan(v));
3406
3510
  assert(!isinf(v));
3407
3511
  }
3408
- #endif
3512
+ #endif // NDEBUG
3409
3513
  }
3410
3514
  }
3411
3515
 
@@ -3488,7 +3592,7 @@ static void ggml_compute_forward_geglu_quick_f32(
3488
3592
  assert(!isnan(x));
3489
3593
  assert(!isinf(x));
3490
3594
  }
3491
- #endif
3595
+ #endif // NDEBUG
3492
3596
  }
3493
3597
  }
3494
3598
 
@@ -3548,7 +3652,7 @@ static void ggml_compute_forward_geglu_quick_f16(
3548
3652
  assert(!isnan(v));
3549
3653
  assert(!isinf(v));
3550
3654
  }
3551
- #endif
3655
+ #endif // NDEBUG
3552
3656
  }
3553
3657
  }
3554
3658
 
@@ -3643,11 +3747,27 @@ void ggml_compute_forward_norm(
3643
3747
 
3644
3748
  // ggml_compute_forward_group_rms_norm
3645
3749
 
3750
+ // fusion kinds that can be combined with the rms_norm computation in a single pass.
3751
+ // extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
3752
+ enum ggml_rms_norm_fuse_op {
3753
+ GGML_RMS_NORM_FUSE_OP_NONE,
3754
+ GGML_RMS_NORM_FUSE_OP_MUL,
3755
+ };
3756
+
3757
+ template <ggml_rms_norm_fuse_op FUSE_OP>
3646
3758
  static void ggml_compute_forward_rms_norm_f32(
3647
3759
  const ggml_compute_params * params,
3648
- ggml_tensor * dst) {
3760
+ ggml_tensor * dst_rms_norm,
3761
+ ggml_tensor * dst_fused = nullptr) {
3649
3762
 
3650
- const ggml_tensor * src0 = dst->src[0];
3763
+ const ggml_tensor * src0 = dst_rms_norm->src[0];
3764
+ const ggml_tensor * src1 = nullptr;
3765
+ ggml_tensor * dst = dst_rms_norm;
3766
+
3767
+ if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3768
+ src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
3769
+ dst = dst_fused;
3770
+ }
3651
3771
 
3652
3772
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
3653
3773
 
@@ -3656,11 +3776,10 @@ static void ggml_compute_forward_rms_norm_f32(
3656
3776
  const int ith = params->ith;
3657
3777
  const int nth = params->nth;
3658
3778
 
3659
- GGML_TENSOR_UNARY_OP_LOCALS
3779
+ GGML_TENSOR_BINARY_OP_LOCALS
3660
3780
 
3661
3781
  float eps;
3662
- memcpy(&eps, dst->op_params, sizeof(float));
3663
-
3782
+ memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
3664
3783
  GGML_ASSERT(eps >= 0.0f);
3665
3784
 
3666
3785
  // TODO: optimize
@@ -3670,25 +3789,32 @@ static void ggml_compute_forward_rms_norm_f32(
3670
3789
  const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3671
3790
 
3672
3791
  ggml_float sum = 0.0;
3792
+ // worth switching to explicit SIMD?
3673
3793
  for (int64_t i00 = 0; i00 < ne00; i00++) {
3674
3794
  sum += (ggml_float)(x[i00] * x[i00]);
3675
3795
  }
3676
3796
 
3677
- const float mean = sum/ne00;
3678
-
3679
- float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3680
-
3681
- memcpy(y, x, ne00 * sizeof(float));
3682
- // for (int i00 = 0; i00 < ne00; i00++) {
3683
- // y[i00] = x[i00];
3684
- // }
3685
-
3797
+ const float mean = sum/ne00;
3686
3798
  const float scale = 1.0f/sqrtf(mean + eps);
3687
3799
 
3688
3800
  // if you hit this, likely you got an inf somewhere earlier
3689
3801
  assert(scale > 0.0f);
3690
3802
 
3691
- ggml_vec_scale_f32(ne00, y, scale);
3803
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3804
+
3805
+ if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3806
+ const int64_t i11 = i01 % ne11;
3807
+ const int64_t i12 = i02 % ne12;
3808
+ const int64_t i13 = i03 % ne13;
3809
+ const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3810
+
3811
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
3812
+ y[i00] = x[i00] * scale * w[i00];
3813
+ }
3814
+ } else {
3815
+ memcpy(y, x, ne00 * sizeof(float));
3816
+ ggml_vec_scale_f32(ne00, y, scale);
3817
+ }
3692
3818
  }
3693
3819
  }
3694
3820
  }
@@ -3703,7 +3829,31 @@ void ggml_compute_forward_rms_norm(
3703
3829
  switch (src0->type) {
3704
3830
  case GGML_TYPE_F32:
3705
3831
  {
3706
- ggml_compute_forward_rms_norm_f32(params, dst);
3832
+ ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
3833
+ } break;
3834
+ default:
3835
+ {
3836
+ GGML_ABORT("fatal error");
3837
+ }
3838
+ }
3839
+ }
3840
+
3841
+ // Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
3842
+ // This avoids materializing the intermediate rms_norm result in memory.
3843
+ void ggml_compute_forward_rms_norm_mul_fused(
3844
+ const ggml_compute_params * params,
3845
+ ggml_tensor * dst_rms_norm,
3846
+ ggml_tensor * dst_mul) {
3847
+
3848
+ GGML_ASSERT(dst_mul != nullptr);
3849
+ GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
3850
+
3851
+ const ggml_tensor * src0 = dst_rms_norm->src[0];
3852
+
3853
+ switch (src0->type) {
3854
+ case GGML_TYPE_F32:
3855
+ {
3856
+ ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
3707
3857
  } break;
3708
3858
  default:
3709
3859
  {
@@ -3858,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32(
3858
4008
  // dx := scale(dx, rrms)
3859
4009
  float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3860
4010
 
3861
- // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
3862
- ggml_vec_cpy_f32 (ne00, dx, x);
3863
- // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
3864
- ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
3865
- ggml_vec_acc_f32 (ne00, dx, dz);
3866
- ggml_vec_scale_f32(ne00, dx, rrms);
4011
+ // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms
4012
+ // note: https://github.com/ggml-org/ggml/issues/1491
4013
+ const float scale_x = (float) (-sum_xdz) / sum_eps;
4014
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
4015
+ dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms;
4016
+ }
3867
4017
  }
3868
4018
  }
3869
4019
  }
@@ -4264,12 +4414,14 @@ void ggml_compute_forward_out_prod(
4264
4414
  const ggml_tensor * src0 = dst->src[0];
4265
4415
 
4266
4416
  switch (src0->type) {
4417
+ case GGML_TYPE_Q1_0:
4267
4418
  case GGML_TYPE_Q4_0:
4268
4419
  case GGML_TYPE_Q4_1:
4269
4420
  case GGML_TYPE_Q5_0:
4270
4421
  case GGML_TYPE_Q5_1:
4271
4422
  case GGML_TYPE_Q8_0:
4272
4423
  case GGML_TYPE_MXFP4:
4424
+ case GGML_TYPE_NVFP4:
4273
4425
  case GGML_TYPE_Q2_K:
4274
4426
  case GGML_TYPE_Q3_K:
4275
4427
  case GGML_TYPE_Q4_K:
@@ -4538,6 +4690,7 @@ void ggml_compute_forward_set(
4538
4690
  } break;
4539
4691
  case GGML_TYPE_F16:
4540
4692
  case GGML_TYPE_BF16:
4693
+ case GGML_TYPE_Q1_0:
4541
4694
  case GGML_TYPE_Q4_0:
4542
4695
  case GGML_TYPE_Q4_1:
4543
4696
  case GGML_TYPE_Q5_0:
@@ -4545,6 +4698,7 @@ void ggml_compute_forward_set(
4545
4698
  case GGML_TYPE_Q8_0:
4546
4699
  case GGML_TYPE_Q8_1:
4547
4700
  case GGML_TYPE_MXFP4:
4701
+ case GGML_TYPE_NVFP4:
4548
4702
  case GGML_TYPE_Q2_K:
4549
4703
  case GGML_TYPE_Q3_K:
4550
4704
  case GGML_TYPE_Q4_K:
@@ -4760,6 +4914,7 @@ void ggml_compute_forward_get_rows(
4760
4914
  const ggml_tensor * src0 = dst->src[0];
4761
4915
 
4762
4916
  switch (src0->type) {
4917
+ case GGML_TYPE_Q1_0:
4763
4918
  case GGML_TYPE_Q4_0:
4764
4919
  case GGML_TYPE_Q4_1:
4765
4920
  case GGML_TYPE_Q5_0:
@@ -4767,6 +4922,7 @@ void ggml_compute_forward_get_rows(
4767
4922
  case GGML_TYPE_Q8_0:
4768
4923
  case GGML_TYPE_Q8_1:
4769
4924
  case GGML_TYPE_MXFP4:
4925
+ case GGML_TYPE_NVFP4:
4770
4926
  case GGML_TYPE_Q2_K:
4771
4927
  case GGML_TYPE_Q3_K:
4772
4928
  case GGML_TYPE_Q4_K:
@@ -5239,7 +5395,7 @@ static void ggml_compute_forward_soft_max_f32(
5239
5395
  //printf("p[%d] = %f\n", i, p[i]);
5240
5396
  assert(!isnan(wp[i]));
5241
5397
  }
5242
- #endif
5398
+ #endif // NDEBUG
5243
5399
 
5244
5400
  float max = -INFINITY;
5245
5401
  ggml_vec_max_f32(ne00, &max, wp);
@@ -5264,7 +5420,7 @@ static void ggml_compute_forward_soft_max_f32(
5264
5420
  assert(!isnan(dp[i]));
5265
5421
  assert(!isinf(dp[i]));
5266
5422
  }
5267
- #endif
5423
+ #endif // NDEBUG
5268
5424
  }
5269
5425
  }
5270
5426
  }
@@ -5338,7 +5494,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5338
5494
  assert(!isnan(dy[i]));
5339
5495
  assert(!isnan(y[i]));
5340
5496
  }
5341
- #endif
5497
+ #endif // NDEBUG
5342
5498
  // Jii = yi - yi*yi
5343
5499
  // Jij = -yi*yj
5344
5500
  // J = diag(y)-y.T*y
@@ -5371,7 +5527,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
5371
5527
  assert(!isnan(dx[i]));
5372
5528
  assert(!isinf(dx[i]));
5373
5529
  }
5374
- #endif
5530
+ #endif // NDEBUG
5375
5531
  }
5376
5532
  }
5377
5533
 
@@ -5484,6 +5640,7 @@ void ggml_compute_forward_clamp(
5484
5640
  ggml_compute_forward_clamp_f16(params, dst);
5485
5641
  } break;
5486
5642
  case GGML_TYPE_BF16:
5643
+ case GGML_TYPE_Q1_0:
5487
5644
  case GGML_TYPE_Q4_0:
5488
5645
  case GGML_TYPE_Q4_1:
5489
5646
  case GGML_TYPE_Q5_0:
@@ -5491,6 +5648,7 @@ void ggml_compute_forward_clamp(
5491
5648
  case GGML_TYPE_Q8_0:
5492
5649
  case GGML_TYPE_Q8_1:
5493
5650
  case GGML_TYPE_MXFP4:
5651
+ case GGML_TYPE_NVFP4:
5494
5652
  case GGML_TYPE_Q2_K:
5495
5653
  case GGML_TYPE_Q3_K:
5496
5654
  case GGML_TYPE_Q4_K:
@@ -5739,28 +5897,33 @@ static void ggml_compute_forward_rope_flt(
5739
5897
 
5740
5898
  const int32_t * pos = (const int32_t *) src1->data;
5741
5899
 
5900
+ int64_t last_i2 = -1;
5901
+
5742
5902
  for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5743
5903
  for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5744
-
5745
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5746
- if (!mrope_used) {
5747
- const int64_t p = pos[i2];
5748
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5749
- }
5750
- else {
5751
- const int64_t p_t = pos[i2];
5752
- const int64_t p_h = pos[i2 + ne2];
5753
- const int64_t p_w = pos[i2 + ne2 * 2];
5754
- const int64_t p_e = pos[i2 + ne2 * 3];
5755
- ggml_mrope_cache_init(
5756
- p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5757
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5758
- }
5759
-
5760
5904
  for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5761
- if (ir++ < ir0) continue;
5905
+ if (ir++ < ir0) continue; // skip rows mapped to other threads
5762
5906
  if (ir > ir1) break;
5763
5907
 
5908
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5909
+ if (last_i2 != i2) {
5910
+ if (!mrope_used) {
5911
+ const int64_t p = pos[i2];
5912
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5913
+ }
5914
+ else {
5915
+ const int64_t p_t = pos[i2];
5916
+ const int64_t p_h = pos[i2 + ne2];
5917
+ const int64_t p_w = pos[i2 + ne2 * 2];
5918
+ const int64_t p_e = pos[i2 + ne2 * 3];
5919
+ ggml_mrope_cache_init(
5920
+ p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5921
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5922
+ }
5923
+
5924
+ last_i2 = i2;
5925
+ }
5926
+
5764
5927
  T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5765
5928
  T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5766
5929
 
@@ -6129,7 +6292,7 @@ static void ggml_compute_forward_im2col_f16(
6129
6292
  const ggml_tensor * src1 = dst->src[1];
6130
6293
 
6131
6294
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
6132
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
6295
+ GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
6133
6296
  GGML_ASSERT( dst->type == GGML_TYPE_F16);
6134
6297
 
6135
6298
  GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6160,7 +6323,7 @@ static void ggml_compute_forward_im2col_f16(
6160
6323
  int ofs1 = is_2D ? nb12 : nb11;
6161
6324
 
6162
6325
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6163
- GGML_ASSERT(nb10 == sizeof(float));
6326
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
6164
6327
 
6165
6328
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6166
6329
  {
@@ -6173,7 +6336,12 @@ static void ggml_compute_forward_im2col_f16(
6173
6336
 
6174
6337
  // micro kernel
6175
6338
  ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6176
- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6339
+ const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6340
+ ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6341
+ : nullptr; // [IH, IW]
6342
+ const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6343
+ ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6344
+ : nullptr; // [IH, IW]
6177
6345
 
6178
6346
  for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
6179
6347
  for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6183,7 +6351,11 @@ static void ggml_compute_forward_im2col_f16(
6183
6351
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6184
6352
  dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6185
6353
  } else {
6186
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6354
+ if (src_data_f32 != nullptr) {
6355
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
6356
+ } else {
6357
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6358
+ }
6187
6359
  }
6188
6360
  }
6189
6361
  }
@@ -6558,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
6558
6730
  return (coord + size) % size; // adding size avoids negative number weirdness
6559
6731
  }
6560
6732
 
6733
+ // ggml_compute_forward_col2im_1d
6734
+ //
6735
+ // Scatter-add columns [K*OC, T_in] -> signal [T_out, OC]
6736
+ // where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs.
6737
+ // Parallelized over the time axis so the split stays balanced whatever OC is.
6738
+ // Supports F32, F16, BF16 input/output (same type), F32 accumulator.
6739
+
6740
+ template <typename elem_t>
6741
+ static void ggml_compute_forward_col2im_1d_impl(
6742
+ const ggml_compute_params * params,
6743
+ ggml_tensor * dst) {
6744
+
6745
+ const ggml_tensor * src = dst->src[0]; // [K*OC, T_in]
6746
+
6747
+ GGML_ASSERT(ggml_is_contiguous(src));
6748
+ GGML_ASSERT(ggml_is_contiguous(dst));
6749
+
6750
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6751
+ const int32_t OC = ((const int32_t *)(dst->op_params))[1];
6752
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6753
+
6754
+ const int64_t K_OC = src->ne[0];
6755
+ const int64_t T_in = src->ne[1];
6756
+ const int64_t K = K_OC / OC;
6757
+ const int64_t T_out = dst->ne[0];
6758
+
6759
+ const elem_t * col_data = (const elem_t *) src->data;
6760
+ elem_t * dst_data = (elem_t *) dst->data;
6761
+
6762
+ const int ith = params->ith;
6763
+ const int nth = params->nth;
6764
+
6765
+ // Parallelize over the time axis: the split stays balanced whatever OC is,
6766
+ // down to OC = 1 for mono audio, and threads read disjoint column bands
6767
+ const int64_t dr = (T_out + nth - 1) / nth;
6768
+ const int64_t it0 = dr * ith;
6769
+ const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out;
6770
+
6771
+ for (int64_t oc = 0; oc < OC; oc++) {
6772
+ for (int64_t t_out = it0; t_out < it1; t_out++) {
6773
+ const int64_t t_abs = t_out + p0; // absolute position in uncropped signal
6774
+ // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K
6775
+ int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s)
6776
+ if (t_in_min < 0) t_in_min = 0;
6777
+ int64_t t_in_max = t_abs / s0;
6778
+ if (t_in_max >= T_in) t_in_max = T_in - 1;
6779
+
6780
+ float sum = 0.0f;
6781
+ for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
6782
+ int64_t k = t_abs - t_in * s0;
6783
+ if (k >= 0 && k < K) {
6784
+ // col layout: [K*OC, T_in], element (oc*K+k, t_in)
6785
+ sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]);
6786
+ }
6787
+ }
6788
+ // dst layout: [T_out, OC], element (t_out, oc)
6789
+ dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum);
6790
+ }
6791
+ }
6792
+ }
6793
+
6794
+ void ggml_compute_forward_col2im_1d(
6795
+ const ggml_compute_params * params,
6796
+ ggml_tensor * dst) {
6797
+ switch (dst->src[0]->type) {
6798
+ case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break;
6799
+ case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break;
6800
+ case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break;
6801
+ default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type);
6802
+ }
6803
+ }
6804
+
6561
6805
  // ggml_compute_forward_conv_2d
6562
6806
 
6563
6807
 
@@ -6838,16 +7082,15 @@ void ggml_compute_forward_conv_3d(
6838
7082
  ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6839
7083
  }
6840
7084
 
6841
- // ggml_compute_forward_conv_transpose_2d
6842
-
6843
- void ggml_compute_forward_conv_transpose_2d(
6844
- const ggml_compute_params * params,
6845
- ggml_tensor * dst) {
7085
+ template <typename kernel_t>
7086
+ static void ggml_compute_forward_conv_transpose_2d_impl(
7087
+ const ggml_compute_params * params,
7088
+ ggml_tensor * dst) {
6846
7089
 
6847
7090
  const ggml_tensor * src0 = dst->src[0];
6848
7091
  const ggml_tensor * src1 = dst->src[1];
6849
7092
 
6850
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
7093
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
6851
7094
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6852
7095
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
6853
7096
 
@@ -6858,7 +7101,7 @@ void ggml_compute_forward_conv_transpose_2d(
6858
7101
 
6859
7102
  const int nk = ne00*ne01*ne02*ne03;
6860
7103
 
6861
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
7104
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
6862
7105
  GGML_ASSERT(nb10 == sizeof(float));
6863
7106
 
6864
7107
  if (ith == 0) {
@@ -6866,12 +7109,12 @@ void ggml_compute_forward_conv_transpose_2d(
6866
7109
 
6867
7110
  // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
6868
7111
  {
6869
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7112
+ kernel_t * const wdata = (kernel_t *) params->wdata + 0;
6870
7113
 
6871
7114
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6872
7115
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6873
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
6874
- ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
7116
+ const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
7117
+ kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
6875
7118
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6876
7119
  for (int64_t i00 = 0; i00 < ne00; i00++) {
6877
7120
  dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
@@ -6883,13 +7126,17 @@ void ggml_compute_forward_conv_transpose_2d(
6883
7126
 
6884
7127
  // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
6885
7128
  {
6886
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
7129
+ kernel_t * const wdata = (kernel_t *) params->wdata + nk;
6887
7130
  for (int i12 = 0; i12 < ne12; i12++) {
6888
7131
  for (int i11 = 0; i11 < ne11; i11++) {
6889
7132
  const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
6890
- ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
7133
+ kernel_t * dst_data = wdata + i11*ne10*ne12;
6891
7134
  for (int i10 = 0; i10 < ne10; i10++) {
6892
- dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
7135
+ if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
7136
+ dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
7137
+ } else {
7138
+ dst_data[i10*ne12 + i12] = src[i10];
7139
+ }
6893
7140
  }
6894
7141
  }
6895
7142
  }
@@ -6911,21 +7158,27 @@ void ggml_compute_forward_conv_transpose_2d(
6911
7158
  const int ip0 = dp*ith;
6912
7159
  const int ip1 = MIN(ip0 + dp, np);
6913
7160
 
6914
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6915
- ggml_fp16_t * const wdata_src = wdata + nk;
7161
+ kernel_t * const wdata = (kernel_t *) params->wdata + 0;
7162
+ kernel_t * const wdata_src = wdata + nk;
6916
7163
 
6917
7164
  for (int i2 = ip0; i2 < ip1; i2++) { // Cout
6918
7165
  float * dst_data = (float *)((char *) dst->data + i2*nb2);
6919
- ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
7166
+ kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
6920
7167
  for (int i11 = 0; i11 < ne11; i11++) {
6921
7168
  for (int i10 = 0; i10 < ne10; i10++) {
6922
7169
  const int i1n = i11*ne10*ne12 + i10*ne12;
6923
7170
  for (int i01 = 0; i01 < ne01; i01++) {
6924
7171
  for (int i00 = 0; i00 < ne00; i00++) {
6925
7172
  float v = 0;
6926
- ggml_vec_dot_f16(ne03, &v, 0,
6927
- wdata_src + i1n, 0,
6928
- wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7173
+ if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
7174
+ ggml_vec_dot_f16(ne03, &v, 0,
7175
+ wdata_src + i1n, 0,
7176
+ wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7177
+ } else {
7178
+ ggml_vec_dot_f32(ne03, &v, 0,
7179
+ wdata_src + i1n, 0,
7180
+ wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7181
+ }
6929
7182
  dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
6930
7183
  }
6931
7184
  }
@@ -6934,19 +7187,41 @@ void ggml_compute_forward_conv_transpose_2d(
6934
7187
  }
6935
7188
  }
6936
7189
 
6937
- // ggml_compute_forward_conv_2d_dw
7190
+ void ggml_compute_forward_conv_transpose_2d(
7191
+ const ggml_compute_params * params,
7192
+ ggml_tensor * dst) {
6938
7193
 
6939
- struct ggml_conv_2d_dw_params {
6940
- int64_t channels;
6941
- int64_t batch;
6942
- int64_t src_w;
6943
- int64_t src_h;
6944
- int64_t dst_w;
6945
- int64_t dst_h;
6946
- int64_t knl_w;
6947
- int64_t knl_h;
6948
- int stride_x;
6949
- int stride_y;
7194
+ const ggml_tensor * src0 = dst->src[0];
7195
+
7196
+ switch (src0->type) {
7197
+ case GGML_TYPE_F16:
7198
+ {
7199
+ ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
7200
+ } break;
7201
+ case GGML_TYPE_F32:
7202
+ {
7203
+ ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
7204
+ } break;
7205
+ default:
7206
+ {
7207
+ GGML_ABORT("fatal error");
7208
+ }
7209
+ }
7210
+ }
7211
+
7212
+ // ggml_compute_forward_conv_2d_dw
7213
+
7214
+ struct ggml_conv_2d_dw_params {
7215
+ int64_t channels;
7216
+ int64_t batch;
7217
+ int64_t src_w;
7218
+ int64_t src_h;
7219
+ int64_t dst_w;
7220
+ int64_t dst_h;
7221
+ int64_t knl_w;
7222
+ int64_t knl_h;
7223
+ int stride_x;
7224
+ int stride_y;
6950
7225
  int pad_x;
6951
7226
  int pad_y;
6952
7227
  int dilation_x;
@@ -7110,12 +7385,13 @@ void ggml_compute_forward_conv_2d_dw(
7110
7385
  }
7111
7386
  }
7112
7387
 
7113
- // ggml_compute_forward_pool_1d_sk_p0
7114
-
7115
- static void ggml_compute_forward_pool_1d_sk_p0(
7388
+ // ggml_compute_forward_pool_1d_ksp
7389
+ static void ggml_compute_forward_pool_1d_ksp(
7116
7390
  const ggml_compute_params * params,
7117
7391
  const ggml_op_pool op,
7118
7392
  const int k,
7393
+ const int s,
7394
+ const int p,
7119
7395
  ggml_tensor * dst) {
7120
7396
 
7121
7397
  const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7402,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
7126
7402
  return;
7127
7403
  }
7128
7404
 
7129
- const char * cdata = (const char *)src->data;
7130
- const char * const data_end = cdata + ggml_nbytes(src);
7131
- float * drow = (float *)dst->data;
7405
+ const int64_t IW = src->ne[0];
7406
+ const int64_t OW = dst->ne[0];
7132
7407
 
7133
- const int64_t rs = dst->ne[0];
7408
+ const int64_t nr = ggml_nrows(src);
7134
7409
 
7135
- while (cdata < data_end) {
7136
- const void * srow = (const void *)cdata;
7137
- int j = 0;
7138
- for (int64_t i = 0; i < rs; ++i) {
7410
+ for (int64_t ir = 0; ir < nr; ++ir) {
7411
+ const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
7412
+ float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
7413
+
7414
+ for (int64_t ow = 0; ow < OW; ++ow) {
7415
+ float res = 0;
7139
7416
  switch (op) {
7140
- case GGML_OP_POOL_AVG: drow[i] = 0; break;
7141
- case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
7417
+ case GGML_OP_POOL_AVG: res = 0.0f; break;
7418
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7142
7419
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7143
7420
  }
7421
+
7422
+ int count = 0;
7423
+ const int base = (int) ow * s - p;
7424
+
7144
7425
  for (int ki = 0; ki < k; ++ki) {
7145
- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7426
+ const int j = base + ki;
7427
+ if (j < 0 || j >= (int) IW) {
7428
+ continue;
7429
+ }
7430
+
7431
+ float v;
7432
+ if (src->type == GGML_TYPE_F32) {
7433
+ v = ((const float *) srow_bytes)[j];
7434
+ } else {
7435
+ v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7436
+ }
7437
+
7146
7438
  switch (op) {
7147
- case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
7148
- case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
7149
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7439
+ case GGML_OP_POOL_AVG: res += v; break;
7440
+ case GGML_OP_POOL_MAX: res = std::max(v, res); break;
7441
+ case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7150
7442
  }
7151
- ++j;
7443
+
7444
+ ++count;
7152
7445
  }
7446
+
7153
7447
  switch (op) {
7154
- case GGML_OP_POOL_AVG: drow[i] /= k; break;
7155
- case GGML_OP_POOL_MAX: break;
7448
+ case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7449
+ case GGML_OP_POOL_MAX: break;
7156
7450
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7157
7451
  }
7158
- }
7159
7452
 
7160
- cdata += src->nb[1];
7161
- drow += rs;
7453
+ drow[ow] = res;
7454
+ }
7162
7455
  }
7163
7456
  }
7164
7457
 
@@ -7173,10 +7466,8 @@ void ggml_compute_forward_pool_1d(
7173
7466
  const int k0 = opts[1];
7174
7467
  const int s0 = opts[2];
7175
7468
  const int p0 = opts[3];
7176
- GGML_ASSERT(p0 == 0); // padding not supported
7177
- GGML_ASSERT(k0 == s0); // only s = k supported
7178
7469
 
7179
- ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7470
+ ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7180
7471
  }
7181
7472
 
7182
7473
  // ggml_compute_forward_pool_2d
@@ -7194,6 +7485,7 @@ void ggml_compute_forward_pool_2d(
7194
7485
  }
7195
7486
 
7196
7487
  const int32_t * opts = (const int32_t *)dst->op_params;
7488
+
7197
7489
  ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7198
7490
  const int k0 = opts[1];
7199
7491
  const int k1 = opts[2];
@@ -7217,11 +7509,13 @@ void ggml_compute_forward_pool_2d(
7217
7509
  while (cdata < data_end) {
7218
7510
  for (int oy = 0; oy < py; ++oy) {
7219
7511
  float * const drow = dplane + oy * px;
7512
+ float * const out = drow;
7513
+
7220
7514
  for (int ox = 0; ox < px; ++ox) {
7221
- float * const out = drow + ox;
7515
+ float res = 0;
7222
7516
  switch (op) {
7223
- case GGML_OP_POOL_AVG: *out = 0; break;
7224
- case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
7517
+ case GGML_OP_POOL_AVG: res = 0; break;
7518
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7225
7519
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7226
7520
  }
7227
7521
 
@@ -7229,24 +7523,32 @@ void ggml_compute_forward_pool_2d(
7229
7523
  const int iy = offset1 + oy * s1;
7230
7524
 
7231
7525
  for (int ky = 0; ky < k1; ++ky) {
7232
- if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7526
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7527
+ continue;
7528
+ }
7529
+
7233
7530
  const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7234
7531
  for (int kx = 0; kx < k0; ++kx) {
7235
7532
  int j = ix + kx;
7236
- if (j < 0 || j >= src->ne[0]) continue;
7533
+ if (j < 0 || j >= src->ne[0]) {
7534
+ continue;
7535
+ }
7536
+
7237
7537
  const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7238
7538
  switch (op) {
7239
- case GGML_OP_POOL_AVG: *out += srow_j; break;
7240
- case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
7539
+ case GGML_OP_POOL_AVG: res += srow_j; break;
7540
+ case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
7241
7541
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7242
7542
  }
7243
7543
  }
7244
7544
  }
7245
7545
  switch (op) {
7246
- case GGML_OP_POOL_AVG: *out /= ka; break;
7247
- case GGML_OP_POOL_MAX: break;
7546
+ case GGML_OP_POOL_AVG: res /= ka; break;
7547
+ case GGML_OP_POOL_MAX: break;
7248
7548
  case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7249
7549
  }
7550
+
7551
+ out[ox] = res;
7250
7552
  }
7251
7553
  }
7252
7554
 
@@ -7603,8 +7905,7 @@ static void ggml_compute_forward_pad_f32(
7603
7905
 
7604
7906
  const ggml_tensor * src0 = dst->src[0];
7605
7907
 
7606
- GGML_ASSERT(src0->nb[0] == sizeof(float));
7607
- GGML_ASSERT( dst->nb[0] == sizeof(float));
7908
+ assert(dst->nb[0] == sizeof(float));
7608
7909
 
7609
7910
  const int ith = params->ith;
7610
7911
  const int nth = params->nth;
@@ -8016,12 +8317,14 @@ void ggml_compute_forward_top_k(
8016
8317
  }
8017
8318
  }
8018
8319
 
8019
- // ggml_compute_forward_flash_attn_ext
8020
-
8021
8320
  static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8022
8321
  const ggml_compute_params * params,
8023
8322
  ggml_tensor * dst,
8024
- int ir0, int ir1) {
8323
+ int ir0, int ir1,
8324
+ int64_t ic_start, int64_t ic_end,
8325
+ float * partials, int64_t partial_stride) {
8326
+
8327
+ const bool write_partials = (partials != nullptr);
8025
8328
  const ggml_tensor * q = dst->src[0];
8026
8329
  const ggml_tensor * k = dst->src[1];
8027
8330
  const ggml_tensor * v = dst->src[2];
@@ -8098,7 +8401,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8098
8401
 
8099
8402
  int ith = params->ith;
8100
8403
 
8101
- // loop over n_batch and n_head
8102
8404
  for (int ir = ir0; ir < ir1; ++ir) {
8103
8405
  // q indices
8104
8406
  const int iq3 = ir/(neq2*neq1);
@@ -8138,7 +8440,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8138
8440
  // online softmax / attention
8139
8441
  // loop over n_kv and n_head_kv
8140
8442
  // ref: https://arxiv.org/pdf/2112.05682.pdf
8141
- for (int64_t ic = 0; ic < nek1; ++ic) {
8443
+
8444
+ for (int64_t ic = ic_start; ic < ic_end; ++ic) {
8142
8445
  const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8143
8446
  if (mv == -INFINITY) {
8144
8447
  continue;
@@ -8211,8 +8514,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8211
8514
  }
8212
8515
  }
8213
8516
 
8214
- // sinks
8215
- if (sinks) {
8517
+ // sinks - apply only on the first kv-chunk
8518
+ if (sinks && ic_start == 0) {
8216
8519
  const float s = ((float *)((char *) sinks->data))[h];
8217
8520
 
8218
8521
  float ms = 1.0f;
@@ -8220,6 +8523,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8220
8523
 
8221
8524
  if (s > M) {
8222
8525
  ms = expf(M - s);
8526
+ M = s;
8223
8527
  ggml_vec_scale_f32(DV, VKQ32, ms);
8224
8528
  } else {
8225
8529
  vs = expf(s - M);
@@ -8228,20 +8532,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8228
8532
  S = S*ms + vs;
8229
8533
  }
8230
8534
 
8231
- // V /= S
8232
- const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8233
- ggml_vec_scale_f32(DV, VKQ32, S_inv);
8535
+ if (write_partials) {
8536
+ // Write M, S, VKQ to partials for later reduction
8537
+ // partials layout: [M, S, VKQ[DV]] per query head
8538
+ float * partial = partials + ir * partial_stride;
8539
+ partial[0] = M;
8540
+ partial[1] = S;
8541
+ memcpy(partial + 2, VKQ32, DV * sizeof(float));
8542
+ } else {
8543
+ // V /= S
8544
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8545
+ ggml_vec_scale_f32(DV, VKQ32, S_inv);
8234
8546
 
8235
- // dst indices
8236
- const int i1 = iq1;
8237
- const int i2 = iq2;
8238
- const int i3 = iq3;
8547
+ // dst indices
8548
+ const int i1 = iq1;
8549
+ const int i2 = iq2;
8550
+ const int i3 = iq3;
8551
+
8552
+ // permute(0, 2, 1, 3)
8553
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8554
+ }
8555
+ }
8556
+ }
8557
+
8558
+ static void ggml_compute_forward_flash_attn_ext_tiled(
8559
+ const ggml_compute_params * params,
8560
+ ggml_tensor * dst,
8561
+ int ir0, int ir1) {
8562
+ const ggml_tensor * q = dst->src[0];
8563
+ const ggml_tensor * k = dst->src[1];
8564
+ const ggml_tensor * v = dst->src[2];
8565
+ const ggml_tensor * mask = dst->src[3];
8566
+ const ggml_tensor * sinks = dst->src[4];
8567
+
8568
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8569
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8570
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8571
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8572
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8573
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8574
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8575
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8576
+
8577
+ const int64_t DK = nek0;
8578
+ const int64_t DV = nev0;
8579
+ const int64_t N = neq1;
8580
+
8581
+ GGML_ASSERT(ne0 == DV);
8582
+ GGML_ASSERT(ne2 == N);
8583
+
8584
+ // input tensor rows must be contiguous
8585
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8586
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8587
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8588
+
8589
+ GGML_ASSERT(neq0 == DK);
8590
+ GGML_ASSERT(nek0 == DK);
8591
+ GGML_ASSERT(nev0 == DV);
8592
+
8593
+ GGML_ASSERT(neq1 == N);
8594
+
8595
+ // dst cannot be transposed or permuted
8596
+ GGML_ASSERT(nb0 == sizeof(float));
8597
+ GGML_ASSERT(nb0 <= nb1);
8598
+ GGML_ASSERT(nb1 <= nb2);
8599
+ GGML_ASSERT(nb2 <= nb3);
8600
+
8601
+ GGML_ASSERT(k->type == v->type);
8602
+ const ggml_type kv_type = k->type;
8603
+
8604
+
8605
+ // broadcast factors
8606
+ const int64_t rk2 = neq2/nek2;
8607
+ const int64_t rk3 = neq3/nek3;
8608
+
8609
+ const int64_t rv2 = neq2/nev2;
8610
+ const int64_t rv3 = neq3/nev3;
8611
+
8612
+ float scale = 1.0f;
8613
+ float max_bias = 0.0f;
8614
+ float logit_softcap = 0.0f;
8615
+
8616
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
8617
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
8618
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8619
+
8620
+ if (logit_softcap != 0) {
8621
+ scale /= logit_softcap;
8622
+ }
8623
+
8624
+ const uint32_t n_head = neq2;
8625
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8626
+
8627
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8628
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8629
+
8630
+ int ith = params->ith;
8631
+
8632
+ static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
8633
+ static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
8634
+
8635
+ int ir = ir0;
8636
+ while (ir < ir1) {
8637
+ // q indices for the start of this tile
8638
+ const int iq3 = ir/(neq2*neq1);
8639
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8640
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8641
+
8642
+ // Number of valid rows in this tile:
8643
+ // - limited by tile size (Q_TILE_SZ)
8644
+ // - limited by chunk boundary (ir1 - ir)
8645
+ // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
8646
+ const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
8647
+ GGML_ASSERT(tile_rows > 0);
8648
+
8649
+ const uint32_t h = iq2; // head index
8650
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8651
+
8652
+ float S[Q_TILE_SZ];
8653
+ float M[Q_TILE_SZ];
8654
+
8655
+ for (int i = 0 ; i < Q_TILE_SZ; ++i) {
8656
+ S[i] = 0.;
8657
+ M[i] = -INFINITY;
8658
+ }
8659
+
8660
+ // Per-thread scratch layout:
8661
+ // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
8662
+ // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
8663
+ // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
8664
+ // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8665
+ // V32: KV_TILE_SZ * DV (F32 buffer for V tile)
8666
+ // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8667
+ float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
8668
+
8669
+ void * Q_q = base;
8670
+ float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
8671
+ float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
8672
+ float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8673
+ float * V32 = VKQ32 + Q_TILE_SZ * DV;
8674
+ float * K_f32 = V32 + KV_TILE_SZ * DV;
8675
+
8676
+ memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
8677
+ memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8678
+
8679
+ // k indices
8680
+ const int ik3 = iq3 / rk3;
8681
+ const int ik2 = iq2 / rk2;
8682
+
8683
+ // v indices
8684
+ const int iv3 = iq3 / rv3;
8685
+ const int iv2 = iq2 / rv2;
8686
+
8687
+ {
8688
+ float * Q_f32 = (float *)Q_q;
8689
+ for (int tq = 0; tq < tile_rows; tq++) {
8690
+ const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8691
+ memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
8692
+ }
8693
+ for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8694
+ memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
8695
+ }
8696
+ }
8697
+
8698
+ memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
8699
+ memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
8700
+
8701
+ for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
8702
+ const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
8703
+
8704
+ // skip the tile entirely if all the masks are -inf
8705
+ if (mask) {
8706
+ bool can_skip = true;
8707
+ for (int tq = 0; tq < tile_rows; tq++) {
8708
+ const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
8709
+ for (int tk = 0; tk < kv_tile; tk++) {
8710
+ mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
8711
+ if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
8712
+ can_skip = false;
8713
+ }
8714
+ }
8715
+ // Pad remaining mask entries with -inf
8716
+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8717
+ mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
8718
+ }
8719
+ }
8720
+
8721
+ if (can_skip) {
8722
+ continue;
8723
+ }
8724
+ }
8725
+
8726
+ // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8727
+ // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8728
+ for (int tk = 0; tk < kv_tile; tk++) {
8729
+ const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8730
+ if (kv_type == GGML_TYPE_F16) {
8731
+ const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8732
+ for (int64_t dk = 0; dk < DK; dk++) {
8733
+ K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
8734
+ }
8735
+ } else {
8736
+ const float * k_f32_src = (const float *)k_data;
8737
+ for (int64_t dk = 0; dk < DK; dk++) {
8738
+ K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8739
+ }
8740
+ }
8741
+ }
8742
+ memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8743
+ simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
8744
+ ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
8745
+
8746
+ // Set padded KQ entries to -inf so softmax gives them zero weight
8747
+ if (kv_tile < KV_TILE_SZ) {
8748
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8749
+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8750
+ KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
8751
+ }
8752
+ }
8753
+ }
8754
+
8755
+ if (logit_softcap != 0.0f) {
8756
+ ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
8757
+ ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
8758
+ }
8759
+
8760
+ if (mask) {
8761
+ ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
8762
+ }
8763
+
8764
+ bool skip[Q_TILE_SZ] = {};
8765
+
8766
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8767
+ float * kq_row = KQ + tq * KV_TILE_SZ;
8768
+
8769
+ float tile_max;
8770
+ ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
8771
+
8772
+ if (tile_max == -INFINITY) {
8773
+ skip[tq] = true;
8774
+ continue;
8775
+ }
8776
+
8777
+ const float Mold = M[tq];
8778
+ const float Mnew = fmaxf(Mold, tile_max);
8779
+
8780
+ if (Mnew > Mold) {
8781
+ const float ms = expf(Mold - Mnew);
8782
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8783
+ S[tq] *= ms;
8784
+ }
8785
+ M[tq] = Mnew;
8786
+
8787
+
8788
+ S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
8789
+ }
8790
+
8791
+ // V accumulation: VKQ32 += softmax(KQ) * V
8792
+ // Pack V tile to contiguous F32, zero-padded
8793
+ for (int tk = 0; tk < kv_tile; tk++) {
8794
+ const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8795
+ if (kv_type == GGML_TYPE_F16) {
8796
+ ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
8797
+ } else {
8798
+ memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
8799
+ }
8800
+ }
8801
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8802
+ if (skip[tq]) {
8803
+ memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
8804
+ }
8805
+ }
8806
+ simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
8807
+ }
8808
+
8809
+ // sinks (apply only to valid rows in the tile)
8810
+ if (sinks) {
8811
+ const float s = ((float *)((char *) sinks->data))[h];
8812
+
8813
+ for (int tq = 0; tq < tile_rows; tq++) {
8814
+ float ms = 1.0f;
8815
+ float vs = 1.0f;
8816
+
8817
+ if (s > M[tq]) {
8818
+ ms = expf(M[tq] - s);
8819
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8820
+ } else {
8821
+ vs = expf(s - M[tq]);
8822
+ }
8823
+
8824
+ S[tq] = S[tq] * ms + vs;
8825
+ }
8826
+ }
8827
+
8828
+ for (int tq = 0; tq < tile_rows; tq++) {
8829
+ // V /= S
8830
+ const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
8831
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
8832
+
8833
+ // dst indices
8834
+ const int i1 = iq1 + tq;
8835
+ const int i2 = iq2;
8836
+ const int i3 = iq3;
8837
+
8838
+ // permute(0, 2, 1, 3)
8839
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
8840
+ }
8841
+
8842
+ ir += tile_rows;
8843
+ }
8844
+ }
8845
+
8846
+ // Reduction function: combines partial results across KV chunks
8847
+ // Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
8848
+ static void ggml_flash_attn_ext_reduce_partials(
8849
+ const ggml_compute_params * params,
8850
+ ggml_tensor * dst,
8851
+ const int64_t n_chunks,
8852
+ const int64_t chunk_size) {
8853
+
8854
+ const ggml_tensor * q = dst->src[0];
8855
+ const ggml_tensor * k = dst->src[1];
8856
+ const ggml_tensor * v = dst->src[2];
8857
+
8858
+ const int64_t DK = k->ne[0];
8859
+ const int64_t DV = v->ne[0];
8860
+ const int64_t nek1 = k->ne[1];
8861
+ const int64_t n_q_heads = q->ne[2];
8862
+
8863
+ const int ith = params->ith;
8864
+ const int nth = params->nth;
8865
+
8866
+ const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8867
+ float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8868
+
8869
+ const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8870
+ const int64_t partial_size = 2 + DV;
8871
+ const float * partials_base = (const float *) params->wdata + partials_offset;
8872
+
8873
+ // Output layout
8874
+ const int64_t ne1 = dst->ne[1];
8875
+ const int64_t ne2 = dst->ne[2];
8876
+ const size_t nb1 = dst->nb[1];
8239
8877
 
8240
- // original
8241
- //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8878
+ // Each thread reduces a subset of query heads
8879
+ for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8880
+ float M_final = -INFINITY;
8881
+ float S_final = 0.0f;
8882
+ float * VKQ_final = thread_wdata;
8883
+ memset(VKQ_final, 0, DV * sizeof(float));
8242
8884
 
8243
- // permute(0, 2, 1, 3)
8244
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8885
+ // Combine partials from all chunks
8886
+ for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
8887
+ const int64_t ic_start = chunk_idx * chunk_size;
8888
+ if (ic_start >= nek1) continue;
8889
+
8890
+ const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8891
+ const float M_chunk = partial[0];
8892
+ const float S_chunk = partial[1];
8893
+ const float * VKQ_chunk = partial + 2;
8894
+
8895
+ if (S_chunk == 0.0f) continue;
8896
+
8897
+ const float M_new = fmaxf(M_final, M_chunk);
8898
+ const float scale_old = expf(M_final - M_new);
8899
+ const float scale_new = expf(M_chunk - M_new);
8900
+
8901
+ for (int64_t d = 0; d < DV; ++d) {
8902
+ VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8903
+ }
8904
+ S_final = S_final * scale_old + S_chunk * scale_new;
8905
+ M_final = M_new;
8906
+ }
8907
+
8908
+ // Normalize and write to output
8909
+ if (S_final != 0.0f) {
8910
+ const float S_inv = 1.0f / S_final;
8911
+ ggml_vec_scale_f32(DV, VKQ_final, S_inv);
8912
+ }
8913
+ // iq1=0, iq3=0 for decode
8914
+ memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
8245
8915
  }
8246
8916
  }
8247
8917
 
@@ -8266,6 +8936,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8266
8936
  const int64_t DV = nev0;
8267
8937
  const int64_t N = neq1;
8268
8938
 
8939
+
8269
8940
  GGML_ASSERT(ne0 == DV);
8270
8941
  GGML_ASSERT(ne2 == N);
8271
8942
 
@@ -8286,47 +8957,97 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8286
8957
  GGML_ASSERT(nb1 <= nb2);
8287
8958
  GGML_ASSERT(nb2 <= nb3);
8288
8959
 
8289
- // parallelize by q rows using ggml_vec_dot_f32
8290
-
8291
- // total rows in q
8292
- const int64_t nr = neq1*neq2*neq3;
8293
-
8294
- // rows per thread
8295
8960
  const int ith = params->ith;
8296
8961
  const int nth = params->nth;
8297
8962
 
8298
- // disable for NUMA
8299
- const bool disable_chunking = ggml_is_numa();
8963
+ // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8964
+ const bool use_ref = params->use_ref;
8300
8965
 
8301
- // 4x chunks per thread
8302
- int nth_scaled = nth * 4;
8303
- int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8304
- int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8966
+ const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8967
+ const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
8305
8968
 
8306
- if (nth == 1 || nchunk < nth || disable_chunking) {
8307
- nchunk = nth;
8308
- }
8969
+ if (use_split_kv_path) {
8970
+ const int64_t chunk_size = (nek1 + nth - 1) / nth;
8309
8971
 
8310
- if (ith == 0) {
8311
- // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
8312
- ggml_threadpool_chunk_set(params->threadpool, nth);
8313
- }
8972
+ // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8973
+ const int64_t partial_size = 2 + DV;
8974
+ float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8314
8975
 
8315
- ggml_barrier(params->threadpool);
8976
+ const int64_t ic_start = ith * chunk_size;
8977
+ const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
8316
8978
 
8317
- // The number of elements in each chunk
8318
- const int64_t dr = (nr + nchunk - 1) / nchunk;
8979
+ const int64_t partial_stride = nth * partial_size;
8980
+ float * chunk_partials = partials_base + ith * partial_size;
8319
8981
 
8320
- // The first chunk comes from our thread_id, the rest will get auto-assigned.
8321
- int current_chunk = ith;
8982
+ if (ic_start < nek1) {
8983
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
8984
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8985
+ params, dst, q_head, q_head + 1, ic_start, ic_end,
8986
+ chunk_partials, partial_stride);
8987
+ }
8988
+ } else {
8989
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
8990
+ float * q_partials = chunk_partials + q_head * partial_stride;
8991
+ q_partials[0] = -INFINITY; // M
8992
+ q_partials[1] = 0.0f; // S
8993
+ }
8994
+ }
8322
8995
 
8323
- while (current_chunk < nchunk) {
8324
- const int64_t ir0 = dr * current_chunk;
8325
- const int64_t ir1 = MIN(ir0 + dr, nr);
8996
+ ggml_barrier(params->threadpool);
8997
+ ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
8998
+ } else {
8326
8999
 
8327
- ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
9000
+ // total rows in q
9001
+ const int64_t nr = neq1*neq2*neq3;
8328
9002
 
8329
- current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
9003
+ // disable for NUMA
9004
+ const bool disable_chunking = ggml_is_numa();
9005
+
9006
+ // 4x chunks per thread
9007
+ int nth_scaled = nth * 4;
9008
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
9009
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
9010
+
9011
+ if (nth == 1 || nchunk < nth || disable_chunking) {
9012
+ nchunk = nth;
9013
+ }
9014
+
9015
+ if (ith == 0) {
9016
+ ggml_threadpool_chunk_set(params->threadpool, nth);
9017
+ }
9018
+
9019
+ ggml_barrier(params->threadpool);
9020
+
9021
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
9022
+
9023
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
9024
+ bool use_tiled = !use_ref &&
9025
+ (q->type == GGML_TYPE_F32 &&
9026
+ kv_is_f32_or_f16 &&
9027
+ k->type == v->type &&
9028
+ neq1 >= Q_TILE_SZ);
9029
+ #ifdef GGML_SIMD
9030
+ #if defined(__ARM_FEATURE_SVE)
9031
+ const int64_t f32_epr = svcntw();
9032
+ #else
9033
+ const int64_t f32_epr = GGML_F32_EPR;
9034
+ #endif
9035
+ use_tiled &= (DV % f32_epr == 0);
9036
+ #endif
9037
+ int current_chunk = ith;
9038
+
9039
+ while (current_chunk < nchunk) {
9040
+ const int64_t ir0 = dr * current_chunk;
9041
+ const int64_t ir1 = MIN(ir0 + dr, nr);
9042
+
9043
+ if (use_tiled) {
9044
+ ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
9045
+ } else {
9046
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
9047
+ }
9048
+
9049
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
9050
+ }
8330
9051
  }
8331
9052
  }
8332
9053
 
@@ -9107,7 +9828,7 @@ void ggml_compute_forward_win_unpart(
9107
9828
  }
9108
9829
  }
9109
9830
 
9110
- //gmml_compute_forward_unary
9831
+ //ggml_compute_forward_unary
9111
9832
 
9112
9833
  void ggml_compute_forward_unary(
9113
9834
  const ggml_compute_params * params,
@@ -9396,13 +10117,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
9396
10117
  const int ith = params->ith;
9397
10118
  const int nth = params->nth;
9398
10119
 
9399
- if (ith >= HEADS) {
9400
- return;
9401
- }
9402
-
9403
- const int h_start = (HEADS * ith) / nth;
9404
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9405
- (HEADS * (ith + 1)) / nth : HEADS;
10120
+ const int h_start = (HEADS * (ith )) / nth;
10121
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10122
+ (HEADS * (ith + 1)) / nth : HEADS;
9406
10123
 
9407
10124
  float * k = (float *) dst->src[0]->data;
9408
10125
  float * v = (float *) dst->src[1]->data;
@@ -9613,13 +10330,9 @@ static void ggml_compute_forward_gla_f32(
9613
10330
  const int ith = params->ith;
9614
10331
  const int nth = params->nth;
9615
10332
 
9616
- if (ith >= HEADS) {
9617
- return;
9618
- }
9619
-
9620
- const int h_start = (HEADS * ith) / nth;
9621
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9622
- (HEADS * (ith + 1)) / nth : HEADS;
10333
+ const int h_start = (HEADS * (ith )) / nth;
10334
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10335
+ (HEADS * (ith + 1)) / nth : HEADS;
9623
10336
 
9624
10337
  float * k = (float *) dst->src[0]->data;
9625
10338
  float * v = (float *) dst->src[1]->data;
@@ -9870,6 +10583,219 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
9870
10583
  }
9871
10584
  }
9872
10585
 
10586
+ // ggml_compute_forward_gated_delta_net
10587
+ static void ggml_compute_forward_gated_delta_net_one_chunk(
10588
+ const ggml_compute_params * params,
10589
+ ggml_tensor * dst,
10590
+ int64_t ir0,
10591
+ int64_t ir1) {
10592
+
10593
+ ggml_tensor * src_q = dst->src[0];
10594
+ ggml_tensor * src_k = dst->src[1];
10595
+ ggml_tensor * src_v = dst->src[2];
10596
+ ggml_tensor * src_g = dst->src[3];
10597
+ ggml_tensor * src_beta = dst->src[4];
10598
+ ggml_tensor * src_state = dst->src[5];
10599
+
10600
+ const int64_t S_v = src_v->ne[0];
10601
+ const int64_t H = src_v->ne[1];
10602
+ const int64_t n_tokens = src_v->ne[2];
10603
+ const int64_t n_seqs = src_v->ne[3];
10604
+
10605
+ GGML_ASSERT(ggml_is_contiguous_rows(src_q));
10606
+ GGML_ASSERT(ggml_is_contiguous_rows(src_k));
10607
+ GGML_ASSERT(ggml_is_contiguous_rows(src_v));
10608
+ GGML_ASSERT(ggml_is_contiguous(src_g));
10609
+ GGML_ASSERT(ggml_is_contiguous(src_beta));
10610
+ GGML_ASSERT(ggml_is_contiguous(src_state));
10611
+
10612
+ GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
10613
+ GGML_ASSERT(src_beta->ne[0] == 1);
10614
+
10615
+ GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
10616
+ GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
10617
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
10618
+ GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
10619
+ GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
10620
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
10621
+ GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
10622
+ GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
10623
+ GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
10624
+
10625
+ const bool kda = (neg0 == S_v);
10626
+
10627
+ // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
10628
+ const int64_t K = ggml_get_op_params_i32(dst, 0);
10629
+ GGML_ASSERT(K >= 1);
10630
+ // per-seq stride in floats (seq s starts at state + s * seq_stride)
10631
+ const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
10632
+
10633
+ const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
10634
+ const int ith = params->ith;
10635
+
10636
+ float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
10637
+ float * state_work = K > 1 ? (delta + S_v) : nullptr;
10638
+
10639
+ // output layout: [attn_scores | new_states]
10640
+ // attn_scores: S_v * H * n_tokens * n_seqs floats
10641
+ // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K))
10642
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
10643
+ const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
10644
+ float * attn_out_base = (float *)dst->data;
10645
+ float * state_out_base = (float *)dst->data + attn_score_elems;
10646
+
10647
+ // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
10648
+ // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
10649
+
10650
+ const float * state_in_base = (const float *)src_state->data;
10651
+
10652
+ //const int64_t rq1 = nev1 / neq1;
10653
+ //const int64_t rk1 = nev1 / nek1;
10654
+ const int64_t rq3 = nev3 / neq3;
10655
+ const int64_t rk3 = nev3 / nek3;
10656
+
10657
+ const float scale = 1.0f / sqrtf((float) S_v);
10658
+
10659
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10660
+ const int64_t iv1 = ir % H; // head_index
10661
+ const int64_t iv3 = ir / H; // sequence
10662
+
10663
+ const int64_t iq1 = iv1 % neq1;
10664
+ const int64_t ik1 = iv1 % nek1;
10665
+
10666
+ const int64_t iq3 = iv3 / rq3;
10667
+ const int64_t ik3 = iv3 / rk3;
10668
+
10669
+ // For K=1, write directly to the single output slot to avoid an extra memcpy at the end.
10670
+ // For K>1, work in scratch and copy out per-token when the slot is in range.
10671
+ float * s_out = (K > 1)
10672
+ ? state_work
10673
+ : state_out_base + (iv3 * H + iv1) * S_v * S_v;
10674
+
10675
+ // copy input state into the working buffer and operate in-place
10676
+ // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride.
10677
+ const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
10678
+ memcpy(s_out, s_in, S_v * S_v * sizeof(float));
10679
+
10680
+ // attn output pointer for first token of this (head, seq)
10681
+ float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
10682
+
10683
+ for (int64_t t = 0; t < n_tokens; t++) {
10684
+ const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
10685
+ const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
10686
+ const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
10687
+
10688
+ const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
10689
+ const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
10690
+
10691
+ // state is stored transposed: s_out[j*S_v + i] = S[i][j]
10692
+ // so row j of s_out = column j of S (contiguous access)
10693
+
10694
+ if (kda) {
10695
+ // precompute exp(g) into delta scratch (reused below)
10696
+ for (int64_t i = 0; i < S_v; ++i) {
10697
+ delta[i] = expf(g_d[i]);
10698
+ }
10699
+ // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
10700
+ for (int64_t j = 0; j < S_v; ++j) {
10701
+ ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
10702
+ }
10703
+ } else {
10704
+ ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
10705
+ }
10706
+
10707
+ // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
10708
+ for (int64_t j = 0; j < S_v; ++j) {
10709
+ float sum = 0.0f;
10710
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
10711
+ delta[j] = (v_d[j] - sum) * beta_val;
10712
+ }
10713
+
10714
+ // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
10715
+ for (int64_t j = 0; j < S_v; ++j) {
10716
+ ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
10717
+ }
10718
+
10719
+ // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
10720
+ for (int64_t j = 0; j < S_v; ++j) {
10721
+ float sum = 0.0f;
10722
+ ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
10723
+ attn_data[j] = sum * scale;
10724
+ }
10725
+
10726
+ attn_data += S_v * H; // advance to next token
10727
+
10728
+ if (K > 1) {
10729
+ const int64_t target_slot = n_tokens - 1 - t;
10730
+ if (target_slot >= 0 && target_slot < K) {
10731
+ float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
10732
+ (iv3 * H + iv1) * S_v * S_v;
10733
+ memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
10734
+ }
10735
+ }
10736
+ }
10737
+ }
10738
+ }
10739
+
10740
+
10741
+ static void ggml_compute_forward_gated_delta_net_f32(
10742
+ const ggml_compute_params * params,
10743
+ ggml_tensor * dst) {
10744
+
10745
+ ggml_tensor * V = dst->src[2];
10746
+ int64_t nr = V->ne[1] * V->ne[3];
10747
+
10748
+ // disable for NUMA
10749
+ const bool disable_chunking = ggml_is_numa();
10750
+
10751
+ int nth = params->nth;
10752
+ int ith = params->ith;
10753
+
10754
+ // 4x chunks per thread
10755
+ int nth_scaled = nth * 4;
10756
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
10757
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
10758
+
10759
+ if (nth == 1 || nchunk < nth || disable_chunking) {
10760
+ nchunk = nth;
10761
+ }
10762
+
10763
+ if (ith == 0) {
10764
+ ggml_threadpool_chunk_set(params->threadpool, nth);
10765
+ }
10766
+
10767
+ ggml_barrier(params->threadpool);
10768
+
10769
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
10770
+
10771
+ int current_chunk = ith;
10772
+
10773
+ while (current_chunk < nchunk) {
10774
+ const int64_t ir0 = dr * current_chunk;
10775
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10776
+
10777
+ ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
10778
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
10779
+ }
10780
+ }
10781
+
10782
+ void ggml_compute_forward_gated_delta_net(
10783
+ const ggml_compute_params * params,
10784
+ ggml_tensor * dst) {
10785
+ const ggml_tensor * src0 = dst->src[0];
10786
+
10787
+ switch (src0->type) {
10788
+ case GGML_TYPE_F32:
10789
+ {
10790
+ ggml_compute_forward_gated_delta_net_f32(params, dst);
10791
+ } break;
10792
+ default:
10793
+ {
10794
+ GGML_ABORT("fatal error");
10795
+ }
10796
+ }
10797
+ }
10798
+
9873
10799
  // ggml_compute_forward_rwkv_wkv7
9874
10800
 
9875
10801
  static void ggml_compute_forward_rwkv_wkv7_f32(
@@ -9887,13 +10813,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
9887
10813
  const int ith = params->ith;
9888
10814
  const int nth = params->nth;
9889
10815
 
9890
- if (ith >= HEADS) {
9891
- return;
9892
- }
9893
-
9894
- const int h_start = (HEADS * ith) / nth;
9895
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9896
- (HEADS * (ith + 1)) / nth : HEADS;
10816
+ const int h_start = (HEADS * (ith )) / nth;
10817
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10818
+ (HEADS * (ith + 1)) / nth : HEADS;
9897
10819
 
9898
10820
  float * r = (float *) dst->src[0]->data;
9899
10821
  float * w = (float *) dst->src[1]->data;
@@ -10195,7 +11117,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
10195
11117
  assert(!isnan(s0[i]));
10196
11118
  assert(!isnan(s1[i]));
10197
11119
  }
10198
- #endif
11120
+ #endif // NDEBUG
10199
11121
 
10200
11122
  float max = -INFINITY;
10201
11123
  ggml_vec_max_f32(nc, &max, s0);
@@ -10214,7 +11136,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
10214
11136
  assert(!isnan(st[i]));
10215
11137
  assert(!isinf(st[i]));
10216
11138
  }
10217
- #endif
11139
+ #endif // NDEBUG
10218
11140
  }
10219
11141
  sums[ith] = sum_thread;
10220
11142
  ggml_barrier(params->threadpool);
@@ -10287,7 +11209,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10287
11209
  assert(!isnan(s0[i]));
10288
11210
  assert(!isnan(s1[i]));
10289
11211
  }
10290
- #endif
11212
+ #endif // NDEBUG
10291
11213
 
10292
11214
  // soft_max
10293
11215
  float max = -INFINITY;
@@ -10305,7 +11227,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
10305
11227
  assert(!isnan(ds0[i]));
10306
11228
  assert(!isinf(ds0[i]));
10307
11229
  }
10308
- #endif
11230
+ #endif // NDEBUG
10309
11231
  }
10310
11232
  }
10311
11233
 
@@ -10471,3 +11393,95 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
10471
11393
  }
10472
11394
  }
10473
11395
  }
11396
+
11397
+ static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
11398
+ const ggml_tensor * src0 = dst->src[0];
11399
+ const ggml_tensor * src1 = dst->src[1];
11400
+
11401
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
11402
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
11403
+
11404
+ GGML_TENSOR_BINARY_OP_LOCALS
11405
+
11406
+ const int ith = params->ith;
11407
+ const int nth = params->nth;
11408
+
11409
+ const int64_t n = ne10;
11410
+ GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
11411
+
11412
+ const int64_t nr = ne11 * ne12 * ne13;
11413
+ const int64_t rows_per_thread = (nr + nth - 1) / nth;
11414
+ const int64_t start_row = ith * rows_per_thread;
11415
+ const int64_t end_row = MIN(start_row + rows_per_thread, nr);
11416
+
11417
+ const float scale = 1.0f / sqrtf((float)n);
11418
+
11419
+ #if defined(GGML_SIMD)
11420
+ const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
11421
+ #endif
11422
+
11423
+ for (int64_t r = start_row; r < end_row; r++) {
11424
+ const int64_t i13 = r / (ne11 * ne12);
11425
+ const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
11426
+ const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
11427
+
11428
+ const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
11429
+ float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
11430
+
11431
+ for (int64_t j = 0; j < n; j++) {
11432
+ dst_row[j] = src_row[j] * scale;
11433
+ }
11434
+
11435
+ // Scalar passes
11436
+ #if defined(GGML_SIMD)
11437
+ #if defined(__ARM_FEATURE_SVE)
11438
+ const int step = svcntw();
11439
+ #else
11440
+ const int step = GGML_F32_EPR;
11441
+ #endif
11442
+ #else
11443
+ const int step = n;
11444
+ #endif
11445
+ for (int64_t len = 1; len < step && len < n; len <<= 1) {
11446
+ for (int64_t i = 0; i < n; i += 2 * len) {
11447
+ for (int64_t j = 0; j < len; j++) {
11448
+ float u = dst_row[i + j];
11449
+ float v = dst_row[i + len + j];
11450
+ dst_row[i + j] = u + v;
11451
+ dst_row[i + len + j] = u - v;
11452
+ }
11453
+ }
11454
+ }
11455
+
11456
+ // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
11457
+ #if defined(GGML_SIMD)
11458
+ for (int64_t len = step; len < n; len <<= 1) {
11459
+ for (int64_t i = 0; i < n; i += 2 * len) {
11460
+ for (int64_t j = 0; j < len; j += step) {
11461
+ GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
11462
+ GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
11463
+
11464
+ GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
11465
+ GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
11466
+ }
11467
+ }
11468
+ }
11469
+ #endif
11470
+ }
11471
+ }
11472
+
11473
+ void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
11474
+ const ggml_tensor * src1 = dst->src[1];
11475
+
11476
+ switch (src1->type) {
11477
+ case GGML_TYPE_F32:
11478
+ {
11479
+ ggml_compute_forward_fwht_f32(params, dst);
11480
+ }
11481
+ break;
11482
+ default:
11483
+ {
11484
+ GGML_ABORT("fatal error - fwht is F32 only");
11485
+ }
11486
+ }
11487
+ }