whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017) hide show
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,4 +1,5 @@
1
1
  #include "rope.hpp"
2
+ #include "convert.hpp"
2
3
  #include "ggml-sycl/common.hpp"
3
4
  #include "ggml.h"
4
5
 
@@ -15,367 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
15
16
  return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
16
17
  }
17
18
 
18
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
19
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
20
- static void rope_yarn(
21
- float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
22
- float * cos_theta, float * sin_theta) {
23
- // Get n-d rotational scaling corrected for extrapolation
19
+ template <bool forward>
20
+ static void rope_yarn(const float theta_extrap, const float freq_scale,
21
+ const rope_corr_dims corr_dims, const int64_t i0,
22
+ const float ext_factor, float mscale, float &cos_theta,
23
+ float &sin_theta) {
24
24
  float theta_interp = freq_scale * theta_extrap;
25
25
  float theta = theta_interp;
26
26
  if (ext_factor != 0.0f) {
27
- float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
27
+ float ramp_mix =
28
+ rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
28
29
  theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
29
30
 
30
- // Get n-d magnitude scaling corrected for interpolation
31
31
  mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
32
32
  }
33
- *cos_theta = sycl::cos(theta) * mscale;
34
- *sin_theta = sycl::sin(theta) * mscale;
33
+ cos_theta = sycl::cos(theta) * mscale;
34
+ sin_theta = sycl::sin(theta) * mscale;
35
+ if (!forward) {
36
+ sin_theta *= -1.0f;
37
+ }
35
38
  }
36
39
 
37
- template <typename T, bool has_ff>
38
- static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39
- const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41
- const sycl::nd_item<3> & item_ct1) {
42
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
43
-
44
- if (i0 >= ne0) {
40
+ template <bool forward, bool has_ff, typename T, typename D>
41
+ static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
42
+ const int ne02, const int s01, const int s02,
43
+ const int s03, const int s1, const int s2, const int s3,
44
+ const int n_dims, const int32_t *pos,
45
+ const float freq_scale, const float ext_factor,
46
+ const float attn_factor, const rope_corr_dims corr_dims,
47
+ const float theta_scale, const float *freq_factors,
48
+ const int64_t *row_indices, const int set_rows_stride) {
49
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
50
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
51
+ item_ct1.get_local_id(1));
52
+
53
+ if (i0 >= ne00) {
45
54
  return;
46
55
  }
47
56
 
48
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
57
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
58
+ item_ct1.get_local_id(2);
49
59
 
50
- const int row0 = row % ne1;
51
- const int channel0 = row / ne1;
60
+ const uint32_t i3 = row_dst / (ne01 * ne02);
61
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
62
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
52
63
 
53
- const int i = row * ne0 + i0;
54
- const int i2 = channel0 * s2 + row0 * s1 + i0;
64
+ int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
65
+ const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
66
+
67
+ if (set_rows_stride != 0) {
68
+ idst = i1 * s1 + i0;
69
+ idst += row_indices[i2] * set_rows_stride;
70
+ }
55
71
 
72
+ const auto &store_coaelsced = [&](float x0, float x1) {
73
+ if constexpr (std::is_same_v<float, D>) {
74
+ sycl::float2 v = sycl::float2(x0, x1);
75
+ ggml_sycl_memcpy_1<8>(dst + idst, &v);
76
+ } else if constexpr (std::is_same_v<sycl::half, D>) {
77
+ sycl::half2 v = sycl::half2(x0, x1);
78
+ ggml_sycl_memcpy_1<4>(dst + idst, &v);
79
+ }
80
+ };
56
81
  if (i0 >= n_dims) {
57
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
82
+ store_coaelsced(x[ix + 0], x[ix + 1]);
58
83
  return;
59
84
  }
60
85
 
61
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
86
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
62
87
 
63
88
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
64
89
 
65
90
  float cos_theta;
66
91
  float sin_theta;
67
92
 
68
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
93
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
94
+ ext_factor, attn_factor, cos_theta, sin_theta);
69
95
 
70
- const float x0 = x[i2 + 0];
71
- const float x1 = x[i2 + 1];
96
+ const float x0 = x[ix + 0];
97
+ const float x1 = x[ix + 1];
72
98
 
73
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
74
- dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
99
+ store_coaelsced(x0 * cos_theta - x1 * sin_theta,
100
+ x0 * sin_theta + x1 * cos_theta);
75
101
  }
76
102
 
77
- template <typename T, bool has_ff>
78
- static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
79
- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
80
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
81
- const sycl::nd_item<3> & item_ct1) {
82
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
83
-
84
- if (i0 >= ne0) {
103
+ template <bool forward, bool has_ff, typename T, typename D>
104
+ static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
105
+ const int ne02, const int s01, const int s02,
106
+ const int s03, const int s1, const int s2, const int s3,
107
+ const int n_dims, const int32_t *pos,
108
+ const float freq_scale, const float ext_factor,
109
+ const float attn_factor, const rope_corr_dims corr_dims,
110
+ const float theta_scale, const float *freq_factors,
111
+ const int64_t *row_indices, const int set_rows_stride) {
112
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
113
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
114
+ item_ct1.get_local_id(1));
115
+
116
+ if (i0 >= ne00) {
85
117
  return;
86
118
  }
87
119
 
88
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
120
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
121
+ item_ct1.get_local_id(2);
89
122
 
90
- const int row0 = row % ne1;
91
- const int channel0 = row / ne1;
123
+ const uint32_t i3 = row_dst / (ne01 * ne02);
124
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
125
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
92
126
 
93
- const int i = row * ne0 + i0 / 2;
94
- const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
127
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
128
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
129
+
130
+ if (set_rows_stride != 0) {
131
+ idst = i1 * s1 + i0 / 2;
132
+ idst += row_indices[i2] * set_rows_stride;
133
+ }
95
134
 
96
135
  if (i0 >= n_dims) {
97
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
136
+ dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
137
+ dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
138
+
98
139
  return;
99
140
  }
100
141
 
101
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
142
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
102
143
 
103
144
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
104
145
 
105
146
  float cos_theta;
106
147
  float sin_theta;
107
148
 
108
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
149
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
150
+ ext_factor, attn_factor, cos_theta, sin_theta);
109
151
 
110
- const float x0 = x[i2 + 0];
111
- const float x1 = x[i2 + n_dims / 2];
152
+ const float x0 = x[ix + 0];
153
+ const float x1 = x[ix + n_dims / 2];
112
154
 
113
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
114
- dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
155
+ dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
156
+ dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
115
157
  }
116
158
 
117
- template <typename T, bool has_ff>
118
- static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
119
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
122
- const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
123
- // get index pos
124
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125
- if (i0 >= ne0) {
159
+ template <bool forward, bool has_ff, typename T>
160
+ static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
161
+ const int ne02, const int s01, const int s02,
162
+ const int s03, const int s1, const int s2, const int s3,
163
+ const int n_dims, const int32_t *pos,
164
+ const float freq_scale, const float ext_factor,
165
+ const float attn_factor, const rope_corr_dims corr_dims,
166
+ const float theta_scale, const float *freq_factors,
167
+ const mrope_sections sections, const bool is_imrope) {
168
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
169
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
170
+ item_ct1.get_local_id(1));
171
+
172
+ if (i0 >= ne00) {
126
173
  return;
127
174
  }
128
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
129
175
 
130
- const int row_x = row_dst % ne1;
131
- const int channel_x = row_dst / ne1;
132
- const int idst = (row_dst * ne0) + (i0 / 2);
133
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
176
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
177
+ item_ct1.get_local_id(2);
178
+
179
+ const uint32_t i3 = row_dst / (ne01 * ne02);
180
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
181
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
182
+
183
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
184
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
134
185
 
135
186
  if (i0 >= n_dims) {
136
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
187
+ dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
188
+ dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
189
+
137
190
  return;
138
191
  }
139
192
 
140
- const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
193
+ const int sect_dims =
194
+ sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141
195
  const int sec_w = sections.v[1] + sections.v[0];
142
196
  const int sector = (i0 / 2) % sect_dims;
143
197
 
144
-
145
198
  float theta_base = 0.0;
146
199
  if (is_imrope) {
147
- if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
148
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
149
- } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
150
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
151
- } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
152
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
200
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
201
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
202
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
203
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
204
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
205
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
153
206
  } else {
154
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
207
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
155
208
  }
156
209
  } else {
157
210
  if (sector < sections.v[0]) {
158
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
159
- }
160
- else if (sector >= sections.v[0] && sector < sec_w) {
161
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
162
- }
163
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
164
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
165
- }
166
- else if (sector >= sec_w + sections.v[2]) {
167
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
211
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
212
+ } else if (sector >= sections.v[0] && sector < sec_w) {
213
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
214
+ } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
215
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
216
+ } else if (sector >= sec_w + sections.v[2]) {
217
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
168
218
  }
169
219
  }
170
220
 
171
221
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
172
- float cos_theta;
173
- float sin_theta;
174
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
175
- const float x0 = x[ix + 0];
176
- const float x1 = x[ix + n_dims/2];
177
222
 
178
- // store results in dst
179
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
180
- dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
181
- }
223
+ float cos_theta;
224
+ float sin_theta;
182
225
 
226
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
227
+ ext_factor, attn_factor, cos_theta, sin_theta);
183
228
 
229
+ const float x0 = x[ix + 0];
230
+ const float x1 = x[ix + n_dims / 2];
231
+
232
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
233
+ dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
234
+ }
184
235
 
185
- template <typename T, bool has_ff>
186
- static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
187
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
188
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
189
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
190
- const sycl::nd_item<3> & item_ct1) {
191
- // get index pos
192
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
193
- if (i0 >= ne0) {
236
+ template <bool forward, bool has_ff, typename T>
237
+ static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
238
+ const int ne02, const int s01, const int s02,
239
+ const int s03, const int s1, const int s2, const int s3,
240
+ const int n_dims, const int32_t *pos,
241
+ const float freq_scale, const float ext_factor,
242
+ const float attn_factor, const rope_corr_dims corr_dims,
243
+ const float theta_scale, const float *freq_factors,
244
+ const mrope_sections sections) {
245
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
246
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
247
+ item_ct1.get_local_id(1));
248
+
249
+ if (i0 >= ne00) {
194
250
  return;
195
251
  }
196
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
197
- const int row_x = row_dst % ne1;
198
- const int channel_x = row_dst / ne1;
199
- const int idst = (row_dst * ne0) + (i0 / 2);
200
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
252
+
253
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
254
+ item_ct1.get_local_id(2);
255
+
256
+ const uint32_t i3 = row_dst / (ne01 * ne02);
257
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
258
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
259
+
260
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
261
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
201
262
 
202
263
  const int sect_dims = sections.v[0] + sections.v[1];
203
- const int sector = (i0 / 2) % sect_dims;
264
+ const int sec_w = sections.v[1] + sections.v[0];
265
+ const int sector = (i0 / 2) % sect_dims;
204
266
 
205
- float theta_base = 0.0f;
267
+ float theta_base = 0.0;
206
268
  if (sector < sections.v[0]) {
207
269
  const int p = sector;
208
- theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
209
- } else {
210
- // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
270
+ theta_base = pos[i2] * dpct::pow(theta_scale, p);
271
+ } else if (sector >= sections.v[0] && sector < sec_w) {
211
272
  const int p = sector - sections.v[0];
212
- theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
273
+ theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
213
274
  }
214
275
 
215
276
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
216
- float cos_theta;
217
- float sin_theta;
218
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
277
+
278
+ float cos_theta;
279
+ float sin_theta;
280
+
281
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
282
+ ext_factor, attn_factor, cos_theta, sin_theta);
283
+
219
284
  const float x0 = x[ix + 0];
220
285
  const float x1 = x[ix + n_dims];
221
286
 
222
- // store results in dst
223
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
287
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
224
288
  dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
225
289
  }
226
290
 
227
- template <typename T>
228
- static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
229
- const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
230
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
231
- const float * freq_factors, queue_ptr stream) {
232
- GGML_ASSERT(ne0 % 2 == 0);
233
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
234
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
235
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
291
+ template <bool forward, typename T, typename D>
292
+ static void
293
+ rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
294
+ const int ne02, const int s01, const int s02, const int s03,
295
+ const int s1, const int s2, const int s3, const int n_dims,
296
+ const int nr, const int32_t *pos, const float freq_scale,
297
+ const float freq_base, const float ext_factor,
298
+ const float attn_factor, const rope_corr_dims corr_dims,
299
+ const float *freq_factors, const int64_t *row_indices,
300
+ const int set_rows_stride, dpct::queue_ptr stream) {
301
+ GGML_ASSERT(ne00 % 2 == 0);
302
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
303
+ const int n_blocks_x =
304
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
305
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
236
306
 
237
307
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
238
308
 
239
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
240
-
241
309
  if (freq_factors == nullptr) {
242
- /*
243
- DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
244
- the limit. To get the device limit, query
245
- info::device::max_work_group_size. Adjust the work-group size if needed.
246
- */
247
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
248
- rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
249
- theta_scale, freq_factors, item_ct1);
250
- });
310
+ stream->parallel_for(
311
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
312
+ [=](sycl::nd_item<3> item_ct1) {
313
+ GGML_UNUSED(item_ct1);
314
+ rope_norm<forward, false>(
315
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
316
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
317
+ theta_scale, freq_factors, row_indices, set_rows_stride);
318
+ });
251
319
  } else {
252
- /*
253
- DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
254
- the limit. To get the device limit, query
255
- info::device::max_work_group_size. Adjust the work-group size if needed.
256
- */
257
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
258
- rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
259
- theta_scale, freq_factors, item_ct1);
260
- });
320
+ stream->parallel_for(
321
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
322
+ [=](sycl::nd_item<3> item_ct1) {
323
+ GGML_UNUSED(item_ct1);
324
+ rope_norm<forward, true>(
325
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
326
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
327
+ theta_scale, freq_factors, row_indices, set_rows_stride);
328
+ });
261
329
  }
262
330
  }
263
331
 
264
- template <typename T>
265
- static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
266
- const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
267
- const float freq_base, const float ext_factor, const float attn_factor,
268
- const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
269
- GGML_ASSERT(ne0 % 2 == 0);
270
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
271
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
272
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
332
+ template <bool forward, typename T, typename D>
333
+ static void
334
+ rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
335
+ const int ne02, const int s01, const int s02, const int s03,
336
+ const int s1, const int s2, const int s3, const int n_dims,
337
+ const int nr, const int32_t *pos, const float freq_scale,
338
+ const float freq_base, const float ext_factor,
339
+ const float attn_factor, const rope_corr_dims corr_dims,
340
+ const float *freq_factors, const int64_t *row_indices,
341
+ const int set_rows_stride, dpct::queue_ptr stream) {
342
+ GGML_ASSERT(ne00 % 2 == 0);
343
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
344
+ const int n_blocks_x =
345
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
346
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
273
347
 
274
348
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
275
349
 
276
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
277
-
278
350
  if (freq_factors == nullptr) {
279
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
280
- rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
281
- theta_scale, freq_factors, item_ct1);
282
- });
351
+ stream->parallel_for(
352
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
353
+ [=](sycl::nd_item<3> item_ct1) {
354
+ GGML_UNUSED(item_ct1);
355
+ rope_neox<forward, false>(
356
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
357
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
358
+ theta_scale, freq_factors, row_indices, set_rows_stride);
359
+ });
283
360
  } else {
284
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
285
- rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
286
- theta_scale, freq_factors, item_ct1);
287
- });
361
+ stream->parallel_for(
362
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
363
+ [=](sycl::nd_item<3> item_ct1) {
364
+ GGML_UNUSED(item_ct1);
365
+ rope_neox<forward, true>(
366
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
367
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
368
+ theta_scale, freq_factors, row_indices, set_rows_stride);
369
+ });
288
370
  }
289
371
  }
290
372
 
291
- template <typename T>
292
- static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
293
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
294
- const float freq_scale, const float freq_base, const float ext_factor,
295
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
296
- const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
297
- GGML_ASSERT(ne0 % 2 == 0);
298
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
299
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
300
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
301
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
302
-
303
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
304
- // Add FP16 capability check if T could be sycl::half
305
- if constexpr (std::is_same_v<T, sycl::half>) {
306
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
307
- }
308
- // launch kernel
373
+ template <bool forward, typename T>
374
+ static void
375
+ rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
376
+ const int ne02, const int s01, const int s02, const int s03,
377
+ const int s1, const int s2, const int s3, const int n_dims,
378
+ const int nr, const int32_t *pos, const float freq_scale,
379
+ const float freq_base, const float ext_factor,
380
+ const float attn_factor, const rope_corr_dims corr_dims,
381
+ const float *freq_factors, const mrope_sections sections,
382
+ const bool is_imrope, dpct::queue_ptr stream) {
383
+ GGML_ASSERT(ne00 % 2 == 0);
384
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
385
+ const int n_blocks_x =
386
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
387
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
388
+
389
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
390
+
309
391
  if (freq_factors == nullptr) {
310
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
311
- rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
312
- corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
313
- });
392
+ stream->parallel_for(
393
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
394
+ [=](sycl::nd_item<3> item_ct1) {
395
+ GGML_UNUSED(item_ct1);
396
+ rope_multi<forward, false, T>(
397
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
398
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
399
+ theta_scale, freq_factors, sections, is_imrope);
400
+ });
314
401
  } else {
315
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
316
- rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
317
- corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
318
- });
402
+ stream->parallel_for(
403
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
404
+ [=](sycl::nd_item<3> item_ct1) {
405
+ GGML_UNUSED(item_ct1);
406
+ rope_multi<forward, true, T>(
407
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
408
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
409
+ theta_scale, freq_factors, sections, is_imrope);
410
+ });
319
411
  }
320
412
  }
321
413
 
414
+ template <bool forward, typename T>
415
+ static void
416
+ rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
417
+ const int ne02, const int s01, const int s02, const int s03,
418
+ const int s1, const int s2, const int s3, const int n_dims,
419
+ const int nr, const int32_t *pos, const float freq_scale,
420
+ const float freq_base, const float ext_factor,
421
+ const float attn_factor, const rope_corr_dims corr_dims,
422
+ const float *freq_factors, const mrope_sections sections,
423
+ dpct::queue_ptr stream) {
424
+ GGML_ASSERT(ne00 % 2 == 0);
425
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
426
+ const int n_blocks_x =
427
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
428
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
322
429
 
430
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
323
431
 
324
-
325
- // rope vision
326
- template <typename T>
327
- static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
328
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
329
- const float freq_scale, const float freq_base, const float ext_factor,
330
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
331
- const mrope_sections sections, queue_ptr stream) {
332
- GGML_ASSERT(ne0 % 2 == 0);
333
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
334
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
335
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
336
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
337
-
338
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
339
- // Add FP16 capability check if T could be sycl::half
340
- if constexpr (std::is_same_v<T, sycl::half>) {
341
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
342
- }
343
- // launch kernel
344
432
  if (freq_factors == nullptr) {
345
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
346
- rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
347
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
348
- });
433
+ stream->parallel_for(
434
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
435
+ [=](sycl::nd_item<3> item_ct1) {
436
+ GGML_UNUSED(item_ct1);
437
+ rope_vision<forward, false, T>(
438
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
439
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
440
+ theta_scale, freq_factors, sections);
441
+ });
349
442
  } else {
350
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
351
- rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
352
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
353
- });
443
+ stream->parallel_for(
444
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
445
+ [=](sycl::nd_item<3> item_ct1) {
446
+ GGML_UNUSED(item_ct1);
447
+ rope_vision<forward, true, T>(
448
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
449
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
450
+ theta_scale, freq_factors, sections);
451
+ });
354
452
  }
355
453
  }
356
454
 
357
- inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
455
+ template <bool forward>
456
+ void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
457
+ const ggml_tensor *set_rows = nullptr) {
458
+ const ggml_tensor *src0 = dst->src[0];
459
+ const ggml_tensor *src1 = dst->src[1];
460
+ const ggml_tensor *src2 = dst->src[2];
461
+
462
+ const float *src0_d = (const float *)src0->data;
463
+ const float *src1_d = (const float *)src1->data;
464
+
465
+ void *dst_d = dst->data;
466
+ const int64_t *row_indices = nullptr;
467
+ ggml_type dst_type = dst->type;
468
+ int set_rows_stride = 0;
469
+
470
+ if (set_rows != nullptr) {
471
+ GGML_ASSERT(forward);
472
+ dst_d = set_rows->data;
473
+ row_indices = (const int64_t *)set_rows->src[1]->data;
474
+ dst_type = set_rows->type;
475
+ set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
476
+ }
477
+ dpct::queue_ptr stream = ctx.stream();
478
+
479
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
480
+ GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
481
+ GGML_ASSERT(src0->type == dst->type ||
482
+ (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
358
483
 
359
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
360
- GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
361
- GGML_ASSERT(dst->src[0]->type == dst->type);
362
- const int64_t ne00 = dst->src[0]->ne[0]; // head dims
363
- const int64_t ne01 = dst->src[0]->ne[1]; // num heads
364
- const int64_t ne02 = dst->src[0]->ne[2]; // num heads
365
- const int64_t nr = ggml_nrows(dst->src[0]);
484
+ const int64_t ne00 = src0->ne[0]; // head dims
485
+ const int64_t ne01 = src0->ne[1]; // num heads
486
+ const int64_t ne02 = src0->ne[2]; // num heads
487
+ const int64_t nr = ggml_nrows(src0);
366
488
 
367
- const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
368
- const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
489
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
490
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
491
+ const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
369
492
 
493
+ const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
494
+ const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
495
+ const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
370
496
 
371
- //const int n_past = ((int32_t *) dst->op_params)[0];
372
- const int n_dims = ((int32_t *) dst->op_params)[1];
373
- const int mode = ((int32_t *) dst->op_params)[2];
374
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
375
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
497
+ const int n_dims = ((int32_t *)dst->op_params)[1];
498
+ const int mode = ((int32_t *)dst->op_params)[2];
499
+ const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
376
500
  mrope_sections sections;
377
501
 
378
- // RoPE alteration for extended context
379
502
  float freq_base;
380
503
  float freq_scale;
381
504
  float ext_factor;
@@ -383,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
383
506
  float beta_fast;
384
507
  float beta_slow;
385
508
 
386
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
387
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
388
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
389
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
390
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
391
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
392
- memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
509
+ memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
510
+ memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
511
+ memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
512
+ memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
513
+ memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
514
+ memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
515
+ memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
393
516
 
394
517
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
395
518
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -397,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
397
520
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
398
521
 
399
522
  if (is_mrope) {
400
- GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
523
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
524
+ sections.v[2] > 0);
401
525
  }
402
526
 
403
527
  if (is_vision) {
404
- GGML_ASSERT(n_dims == ne00/2);
528
+ GGML_ASSERT(n_dims == ne00 / 2);
405
529
  }
406
530
 
407
- const int32_t * pos = (const int32_t *) dst->src[1]->data;
531
+ const int32_t *pos = (const int32_t *)src1_d;
408
532
 
409
- const float * freq_factors = nullptr;
410
- if (dst->src[2] != nullptr) {
411
- freq_factors = (const float *) dst->src[2]->data;
533
+ const float *freq_factors = nullptr;
534
+ if (src2 != nullptr) {
535
+ freq_factors = (const float *)src2->data;
412
536
  }
413
537
 
414
538
  rope_corr_dims corr_dims;
415
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
416
-
417
- dpct::queue_ptr main_stream = ctx.stream();
418
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
539
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
540
+ beta_slow, corr_dims.v);
419
541
 
420
542
  // compute
421
543
  if (is_neox) {
422
544
  GGML_SYCL_DEBUG("%s: neox path\n", __func__);
423
- if (dst->src[0]->type == GGML_TYPE_F32) {
424
- rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
425
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
426
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
427
- rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
428
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
429
- main_stream);
545
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
546
+ rope_neox_sycl<forward, float, float>(
547
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
548
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
549
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
550
+ set_rows_stride, stream);
551
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
552
+ rope_neox_sycl<forward, float, sycl::half>(
553
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
554
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
555
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
556
+ row_indices, set_rows_stride, stream);
557
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
558
+ rope_neox_sycl<forward, sycl::half, sycl::half>(
559
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
560
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
561
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
562
+ row_indices, set_rows_stride, stream);
430
563
  } else {
431
- GGML_ABORT("fatal error");
564
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
432
565
  }
433
566
  } else if (is_mrope && !is_vision) {
434
567
  GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
435
- if (dst->src[0]->type == GGML_TYPE_F16) {
436
- rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
437
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
438
- freq_factors, sections, is_imrope, main_stream);
439
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
440
- rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
441
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
442
- is_imrope, main_stream);
568
+ if (src0->type == GGML_TYPE_F32) {
569
+ rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
570
+ ne00, ne01, ne02, s01, s02, s03, s1, s2,
571
+ s3, n_dims, nr, pos, freq_scale, freq_base,
572
+ ext_factor, attn_factor, corr_dims,
573
+ freq_factors, sections, is_imrope, stream);
574
+ } else if (src0->type == GGML_TYPE_F16) {
575
+ rope_multi_sycl<forward>(
576
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
577
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
578
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
579
+ sections, is_imrope, stream);
443
580
  } else {
444
581
  GGML_ABORT("Fatal error: Tensor type unsupported!");
445
582
  }
446
583
  } else if (is_vision) {
447
584
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
448
- if (dst->src[0]->type == GGML_TYPE_F16) {
449
- rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
450
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
451
- freq_factors, sections, main_stream);
452
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
453
- rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
454
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
455
- main_stream);
585
+ if (src0->type == GGML_TYPE_F32) {
586
+ rope_vision_sycl<forward>(
587
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
588
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
589
+ ext_factor, attn_factor, corr_dims, freq_factors, sections,
590
+ stream);
591
+ } else if (src0->type == GGML_TYPE_F16) {
592
+ rope_vision_sycl<forward>(
593
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
594
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
595
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
596
+ sections, stream);
456
597
  } else {
457
598
  GGML_ABORT("Fatal error: Tensor type unsupported!");
458
599
  }
459
600
  } else {
460
601
  GGML_SYCL_DEBUG("%s: norm path\n", __func__);
461
- if (dst->src[0]->type == GGML_TYPE_F32) {
462
- rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
463
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
464
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
465
- rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
466
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
467
- main_stream);
602
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
603
+ rope_norm_sycl<forward, float, float>(
604
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
605
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
606
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
607
+ set_rows_stride, stream);
608
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
609
+ rope_norm_sycl<forward, float, sycl::half>(
610
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
611
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
612
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
613
+ row_indices, set_rows_stride, stream);
614
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
615
+ rope_norm_sycl<forward, sycl::half, sycl::half>(
616
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
617
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
618
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
619
+ row_indices, set_rows_stride, stream);
468
620
  } else {
469
- GGML_ABORT("fatal error");
621
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
470
622
  }
471
623
  }
472
624
  }
473
625
 
474
- void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
626
+ void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
475
627
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
476
- ggml_sycl_op_rope(ctx, dst);
628
+
629
+ ggml_sycl_op_rope_impl<true>(ctx, dst);
477
630
  }
478
631
 
632
+ void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
633
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
634
+ ggml_sycl_op_rope_impl<false>(ctx, dst);
635
+ }
636
+
637
+ void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
638
+ ggml_tensor *set_rows) {
639
+ scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
640
+ ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
641
+ }