whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017):
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -15,18 +15,9 @@
15
15
 
16
16
  #include <sycl/sycl.hpp>
17
17
  #include <sycl/half_type.hpp>
18
- #include <syclcompat/math.hpp>
19
- #include <map>
20
-
21
- #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
18
  #include <oneapi/mkl.hpp>
23
- // Allow to use the same namespace for Intel oneMKL and oneMath
24
- namespace oneapi {
25
- namespace math = mkl;
26
- }
27
- #else
28
- #include <oneapi/math.hpp>
29
- #endif
19
+
20
+ #include <map>
30
21
 
31
22
  #include "ggml.h"
32
23
 
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
92
83
  }
93
84
 
94
85
  template <typename Ts> struct matrix_info_t {
95
- oneapi::math::transpose transpose_info[2];
86
+ oneapi::mkl::transpose transpose_info[2];
96
87
  Ts value_info[2];
97
88
  std::int64_t size_info[3];
98
89
  std::int64_t ld_info[3];
99
90
  std::int64_t groupsize_info;
100
91
  };
101
92
 
102
- inline auto get_onemath_backend(sycl::queue& queue)
103
- #if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
104
- -> sycl::queue&
105
- #endif
106
- {
107
- // If the backend is known at compile-time, use oneMath backend_selector to use
108
- // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
109
- // fallback to runtime dispatching.
110
- #if defined(GGML_SYCL_NVIDIA)
111
- return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
112
- #elif defined(GGML_SYCL_AMD)
113
- return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
114
- #elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
115
- return queue;
116
- #else
117
- static_assert(false, "Unsupported backend");
118
- #endif
119
- }
120
-
121
93
  namespace dpct
122
94
  {
123
95
  typedef sycl::queue *queue_ptr;
@@ -1735,7 +1707,7 @@ namespace dpct
1735
1707
  namespace detail
1736
1708
  {
1737
1709
  template <class Ta, class Tb, class Tc, class Ts>
1738
- inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1710
+ inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
1739
1711
  int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1740
1712
  const void * beta, void * c, int ldc) {
1741
1713
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
@@ -1743,7 +1715,7 @@ namespace dpct
1743
1715
  auto data_a = get_memory<const Ta>(a);
1744
1716
  auto data_b = get_memory<const Tb>(b);
1745
1717
  auto data_c = get_memory<Tc>(c);
1746
- oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1718
+ oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
1747
1719
  lda, data_b, ldb, beta_value, data_c, ldc);
1748
1720
  }
1749
1721
 
@@ -1775,7 +1747,7 @@ namespace dpct
1775
1747
  };
1776
1748
 
1777
1749
  template <class Ta, class Tb, class Tc, class Ts>
1778
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1750
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1779
1751
  int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1780
1752
  int ldb, const void * beta, void ** c, int ldc, int batch_size,
1781
1753
  matrix_info_t<float> * matrix_info) {
@@ -1794,8 +1766,8 @@ namespace dpct
1794
1766
  matrix_info->ld_info[2] = ldc;
1795
1767
  matrix_info->groupsize_info = batch_size;
1796
1768
 
1797
- sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1798
- get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
1769
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1770
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1799
1771
  matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1800
1772
  reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1801
1773
  reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
@@ -1804,7 +1776,7 @@ namespace dpct
1804
1776
  }
1805
1777
 
1806
1778
  template <class Ta, class Tb, class Tc, class Ts>
1807
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1779
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1808
1780
  int m, int n, int k, const void * alpha, const void * a, int lda,
1809
1781
  long long int stride_a, const void * b, int ldb, long long int stride_b,
1810
1782
  const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
@@ -1813,7 +1785,7 @@ namespace dpct
1813
1785
  auto data_a = get_memory<const Ta>(a);
1814
1786
  auto data_b = get_memory<const Tb>(b);
1815
1787
  auto data_c = get_memory<Tc>(c);
1816
- oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
1788
+ oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
1817
1789
  data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1818
1790
  data_c, ldc, stride_c, batch_size);
1819
1791
  }
@@ -2300,7 +2272,7 @@ namespace dpct
2300
2272
  sycl::range<3>(x, y, 1), direction);
2301
2273
  }
2302
2274
 
2303
- inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2275
+ inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
2304
2276
  int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2305
2277
  library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2306
2278
  library_data_t scaling_type) {
@@ -2367,7 +2339,7 @@ namespace dpct
2367
2339
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2368
2340
  library_data_t::real_float, library_data_t::real_float):
2369
2341
  {
2370
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2342
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2371
2343
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2372
2344
  break;
2373
2345
  }
@@ -2406,7 +2378,7 @@ namespace dpct
2406
2378
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2407
2379
  library_data_t::real_bfloat16, library_data_t::real_float):
2408
2380
  {
2409
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2381
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2410
2382
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2411
2383
  break;
2412
2384
  }
@@ -2448,7 +2420,7 @@ namespace dpct
2448
2420
  /// \param [in] ldc Leading dimension of C.
2449
2421
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2450
2422
  /// \param [in] scaling_type Data type of the scaling factors.
2451
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2423
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2452
2424
  int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2453
2425
  const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2454
2426
  library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2486,7 +2458,7 @@ namespace dpct
2486
2458
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2487
2459
  library_data_t::real_bfloat16, library_data_t::real_float):
2488
2460
  {
2489
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2461
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2490
2462
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2491
2463
  break;
2492
2464
  }
@@ -2494,7 +2466,7 @@ namespace dpct
2494
2466
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2495
2467
  library_data_t::real_float, library_data_t::real_float):
2496
2468
  {
2497
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2469
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2498
2470
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2499
2471
  break;
2500
2472
  }
@@ -2570,7 +2542,7 @@ namespace dpct
2570
2542
  /// \param [in] stride_c Stride between the different C matrices.
2571
2543
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2572
2544
  /// \param [in] scaling_type Data type of the scaling factors.
2573
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2545
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2574
2546
  int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2575
2547
  long long int stride_a, const void * b, library_data_t b_type, int ldb,
2576
2548
  long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
@@ -2643,7 +2615,7 @@ namespace dpct
2643
2615
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2644
2616
  library_data_t::real_bfloat16, library_data_t::real_float):
2645
2617
  {
2646
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2618
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2647
2619
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2648
2620
  batch_size);
2649
2621
  break;
@@ -2652,7 +2624,7 @@ namespace dpct
2652
2624
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2653
2625
  library_data_t::real_float, library_data_t::real_float):
2654
2626
  {
2655
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2627
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2656
2628
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2657
2629
  batch_size);
2658
2630
  break;
@@ -3025,6 +2997,778 @@ namespace dpct
3025
2997
  return 0;
3026
2998
  }
3027
2999
 
3000
template <int n_nondefault_params, int n_default_params, typename T>
class args_selector;

/// args_selector is a helper class for extracting arguments from either an
/// array of pointers to arguments or a packed buffer of arguments, so they
/// can be passed to a kernel function.
///
/// \tparam n_nondefault_params The number of non-default parameters of the
///         kernel (excluding parameters such as sycl::nd_item).
/// \tparam n_default_params The number of default parameters of the kernel.
/// \tparam R(Ts...) The type of the kernel.
///
/// Example usage:
/// With the following kernel:
///     void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float f=.1) {}
/// and with the declaration:
///     args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
/// we have:
///     selector.get<0>() returns a reference to sycl::float2*,
///     selector.get<1>() returns a reference to int,
///     selector.get<2>() returns a reference to float
///
/// COPY from DPCT header file
/// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
template <int n_nondefault_params, int n_default_params, typename R,
          typename... Ts>
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
private:
    void **kernel_params;
    char *args_buffer;

    /// Maps argument index i to a parameter index of R(Ts...). The first
    /// n_nondefault_params indices map to themselves. Later indices map onto
    /// the trailing n_default_params parameters, which skips any parameters
    /// in between (e.g. sycl::nd_item in the example above).
    template <int i> static constexpr int account_for_default_params() {
        constexpr int n_total_params = sizeof...(Ts);
        if constexpr (i >= n_nondefault_params) {
            return n_total_params - n_default_params +
                   (i - n_nondefault_params);
        } else {
            return i;
        }
    }

public:
    /// Get the type of the ith argument of R(Ts...)
    /// \tparam i Index of parameter to get
    /// \returns Type of ith parameter
    template <int i>
    using arg_type = std::tuple_element_t<account_for_default_params<i>(),
                                          std::tuple<Ts...>>;
    static constexpr int params_num = sizeof...(Ts);

private:
    /// Byte offset of the ith argument inside args_buffer. Arguments are
    /// assumed to be laid out back to back, each at its natural alignment.
    template <int i> static constexpr std::size_t get_offset() {
        if constexpr (i == 0) {
            // args_buffer is assumed to be properly aligned for the first
            // argument.
            return 0;
        } else {
            using T = arg_type<i>;
            constexpr std::size_t prev_past_end =
                get_offset<i - 1>() + sizeof(arg_type<i - 1>);
            // Round the past-the-end offset of argument i-1 up to the next
            // multiple of argument i's alignment. This is a no-op when the
            // offset is already aligned.
            return (prev_past_end + alignof(T) - 1) / alignof(T) * alignof(T);
        }
    }

    /// Scans the zero-terminated `extra` array for the marker value 1 and
    /// returns the pointer stored right after it. Returns nullptr if `extra`
    /// is null or the marker is absent.
    /// NOTE(review): presumably mirrors CUDA's CU_LAUNCH_PARAM_BUFFER_POINTER /
    /// CU_LAUNCH_PARAM_END convention for cuLaunchKernel's `extra`.
    static char *get_args_buffer(void **extra) {
        if (!extra)
            return nullptr;
        for (; (std::size_t)*extra != 0; ++extra) {
            if ((std::size_t)*extra == 1) {
                return static_cast<char *>(*(extra + 1));
            }
        }
        return nullptr;
    }

public:
    /// If kernel_params is non-null, then args_selector will
    /// extract arguments from kernel_params. Otherwise, it
    /// will extract them from the buffer referenced by extra.
    /// \param [in] kernel_params Array of pointers to arguments,
    ///             or a null pointer.
    /// \param [in] extra Array containing pointer to argument buffer.
    args_selector(void **kernel_params, void **extra)
        : kernel_params(kernel_params),
          args_buffer(get_args_buffer(extra)) {}

    /// Get a reference to the ith argument extracted from kernel_params
    /// or extra.
    /// When kernel_params is non-null, the argument is read through
    /// kernel_params[i]. Otherwise it is read from the packed buffer at
    /// get_offset<i>(). Callers must supply one of the two sources;
    /// otherwise the access dereferences a null buffer.
    /// \tparam i Index of argument to get
    /// \returns Reference to the ith argument
    template <int i> arg_type<i> &get() {
        if (kernel_params) {
            return *static_cast<arg_type<i> *>(kernel_params[i]);
        } else {
            return *reinterpret_cast<arg_type<i> *>(args_buffer +
                                                    get_offset<i>());
        }
    }
}; // COPY from DPCT head file
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3108
+
3109
+ /// Utility class for launching SYCL kernels through kernel
3110
+ /// function wrapper.
3111
+ /// For example:
3112
+ /// A SYCL kernel function:
3113
+ /// void kernel_func(int *ptr, sycl::nd_item<3> item);
3114
+ /// Kernel function wrapper:
3115
+ /// void kernel_func_wrapper(int *ptr) {
3116
+ /// sycl::queue queue = *dpct::kernel_launcher::_que;
3117
+ /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
3118
+ /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
3119
+ /// queue.parallel_for(
3120
+ /// nr,
3121
+ /// [=](sycl::nd_item<3> item_ct1) {
3122
+ /// kernel_func(ptr, item_ct1);
3123
+ /// });
3124
+ /// }
3125
+ /// Then launch the kernel through wrapper like:
3126
+ /// typedef void(*fpt)(int *);
3127
+ /// fpt fp = kernel_func_wrapper;
3128
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
3129
+ /// device_ptr);
3130
+ /// If the origin function type is erased, then need to register it first:
3131
+ /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
3132
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
3133
+ /// 0, 0);
3134
+ class kernel_launcher {
3135
+ template <typename FuncT, typename ArgSelector, std::size_t... Index>
3136
+ static void launch_helper(FuncT &&func, ArgSelector &selector,
3137
+ std::index_sequence<Index...>) {
3138
+ func(selector.template get<Index>()...);
3139
+ }
3140
+ static void set_execution_config(dim3 group_range, dim3 local_range,
3141
+ unsigned int local_mem_size,
3142
+ queue_ptr que) {
3143
+ if (que) {
3144
+ _que = que;
3145
+ } else {
3146
+ _que = &get_default_queue();
3147
+ }
3148
+ _nr = sycl::nd_range<3>(
3149
+ static_cast<sycl::range<3>>(group_range * local_range),
3150
+ static_cast<sycl::range<3>>(local_range));
3151
+ _local_mem_size = local_mem_size;
3152
+
3153
+
3154
+ };
3155
+ static inline std::mutex kernel_function_ptr_map_mutex;
3156
+
3157
+ public:
3158
+ /// Variables for storing execution configuration.
3159
+ static inline thread_local sycl::queue *_que = nullptr;
3160
+ static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
3161
+ static inline thread_local unsigned int _local_mem_size = 0;
3162
+ /// Map for retrieving launchable functor from a raw pointer.
3163
+ static inline std::map<
3164
+ const void *,
3165
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
3166
+ kernel_function_ptr_map = {};
3167
+
3168
+ /// Registers a kernel function pointer with a corresponding launchable
3169
+ /// functor.
3170
+ /// \param [in] func Pointer to the kernel function.
3171
+ /// \param [in] launcher Functor to handle kernel invocation.
3172
+ static void register_kernel_ptr(
3173
+ const void *func,
3174
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
3175
+ launcher) {
3176
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3177
+ kernel_function_ptr_map[func] = std::move(launcher);
3178
+ }
3179
+ /// Launches a kernel function with arguments provided directly through
3180
+ /// kernel function wrapper.
3181
+ /// \tparam FuncT Type of the kernel function wrapper.
3182
+ /// \tparam ArgsT Types of kernel arguments.
3183
+ /// \param [in] func Pointer to the kernel function wrapper.
3184
+ /// \param [in] group_range SYCL group range.
3185
+ /// \param [in] local_range SYCL local range.
3186
+ /// \param [in] local_mem_size The size of local memory required by the
3187
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3188
+ /// \param [in] args Kernel arguments.
3189
+ template <typename FuncT, typename... ArgsT>
3190
+ static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
3191
+ launch(FuncT *func, dim3 group_range, dim3 local_range,
3192
+ unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
3193
+ set_execution_config(group_range, local_range, local_mem_size, que);
3194
+ func(args...);
3195
+ }
3196
+ /// Launches a kernel function through registered kernel function
3197
+ /// wrapper. \param [in] func Pointer to the registered kernel function
3198
+ /// wrapper. \param [in] group_range SYCL group range. \param [in]
3199
+ /// local_range SYCL local range. \param [in] args Array of pointers to
3200
+ /// kernel arguments. \param [in] local_mem_size The size of local
3201
+ /// memory required by the kernel function. \param [in] que SYCL queue
3202
+ /// used to execute kernel.
3203
+ static void launch(const void *func, dim3 group_range, dim3 local_range,
3204
+ void **args, unsigned int local_mem_size,
3205
+ queue_ptr que) {
3206
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3207
+ auto Iter = kernel_function_ptr_map.find(func);
3208
+ if (Iter == kernel_function_ptr_map.end()) {
3209
+ throw std::runtime_error("dpct::launch() : no registered "
3210
+ "kernel function wrapper found.");
3211
+ }
3212
+ (Iter->second)(group_range, local_range, args, local_mem_size, que);
3213
+ }
3214
+ /// Launches a kernel function with packed arguments through kernel
3215
+ /// function wrapper.
3216
+ /// \tparam FuncT Type of the kernel function wrapper.
3217
+ /// \param [in] func Pointer to the kernel function wrapper.
3218
+ /// \param [in] group_range SYCL group range.
3219
+ /// \param [in] local_range SYCL local range.
3220
+ /// \param [in] args Array of pointers to kernel arguments.
3221
+ /// \param [in] local_mem_size The size of local memory required by the
3222
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3223
+ template <typename FuncT>
3224
+ static std::enable_if_t<std::is_function_v<FuncT>, void>
3225
+ launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
3226
+ unsigned int local_mem_size, queue_ptr que) {
3227
+ constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
3228
+ set_execution_config(group_range, local_range, local_mem_size, que);
3229
+ args_selector<p_num, p_num, FuncT> selector(args, nullptr);
3230
+ launch_helper(func, selector, std::make_index_sequence<p_num>{});
3231
+ }
3232
+ }; // COPY from DPCT head file
3233
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
3234
+
3235
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3236
+ template <typename T>
3237
+ T select_from_sub_group(
3238
+ sycl::sub_group g,
3239
+ T x,
3240
+ int remote_local_id,
3241
+ int logical_sub_group_size = 32) {
3242
+ unsigned int start_index = g.get_local_linear_id() /
3243
+ logical_sub_group_size *
3244
+ logical_sub_group_size;
3245
+ return sycl::select_from_group(
3246
+ g, x, start_index + remote_local_id % logical_sub_group_size);
3247
+ }
3248
+
3249
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3250
+ template <typename T>
3251
+ void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
3252
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3253
+ int lane = sg.get_local_linear_id();
3254
+
3255
+ int lane_group8_row = lane / 8;
3256
+ int lane_group8_col = lane % 8;
3257
+
3258
+ if (!trans) {
3259
+ // calculate the source lane
3260
+ int src_lane = 2 * lane_group8_row;
3261
+ if (lane_group8_col >= 4)
3262
+ src_lane += 1;
3263
+
3264
+ // Broadcast the address from the source lane
3265
+ auto recv_addr_uintp =
3266
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3267
+
3268
+ // Cast the received address from uintptr_t to the type of 'm'
3269
+ auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
3270
+
3271
+ // Non-transposed load
3272
+ *m = recv_addr[lane_group8_col % 4];
3273
+ } else {
3274
+ // calculate the source lane
3275
+ int src_lane = (lane % 4) * 2;
3276
+
3277
+ // Broadcast the address from the source lane
3278
+ auto recv_addr_uintp_1 =
3279
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3280
+ auto recv_addr_uintp_2 =
3281
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
3282
+
3283
+ // Cast the received address from uintptr_t to 'half *'
3284
+ auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
3285
+ auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
3286
+
3287
+ // Transposed load
3288
+ int index = lane / 4;
3289
+ sycl::half val0 = recv_addr_1[index];
3290
+ sycl::half val1 = recv_addr_2[index];
3291
+
3292
+ // Combine the two 16-bits into one 32-bit value
3293
+ sycl::half2 val = sycl::half2(val0, val1);
3294
+ *m = *reinterpret_cast<T*>(&val);
3295
+ }
3296
+ }
3297
+
3298
/// Loads two fragments: \p m1 from matrix 0 and \p m2 from matrix 1.
template <typename T>
void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
  T *const frags[] = {m1, m2};
  // Fragment k is loaded from the k-th group of 8 address lanes.
  for (unsigned mat = 0; mat < 2; ++mat)
    ldmatrix(addr, frags[mat], trans, mat);
}
3305
+
3306
/// Loads four fragments: \p m1 .. \p m4 from matrices 0 .. 3 respectively.
template <typename T>
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
              bool trans = false) {
  T *const frags[] = {m1, m2, m3, m4};
  // Fragment k is loaded from the k-th group of 8 address lanes.
  for (unsigned mat = 0; mat < 4; ++mat)
    ldmatrix(addr, frags[mat], trans, mat);
}
3318
+
3319
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3320
+
3321
/// Maps the element type of an mma() input fragment (A or B) to the word type
/// that the fragment is packed into and exchanged between work items as.
/// Only this primary template is provided (no specializations follow in this
/// header), so every element type -- f16, bf16, s8, ... -- packs into a
/// 32-bit unsigned word.
/// \tparam [in] ElemT The element type of the input matrix fragments
template <typename ElemT>
struct MMAType {
  using PackType = uint32_t;
};
3331
+
3332
+ /// Each work item of a sub-group (limited to size 32) calling this function
3333
+ /// calculates a subset fragment for the output matrix D using MAD operation
3334
+ /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
3335
+ /// types:
3336
+ /// - m8n8k4 (f32.f16.f16.f32)
3337
+ /// - m8n8k16 (s32.s8.s8.s32)
3338
+ /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
3339
+ /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
3340
+ /// - m16n8k32 (s32.s8.s8.s32)
3341
+ /// Here, m, n & k define the shapes of A, B & C matrices respectively
3342
+ /// (A = [m x k], B = [k x n], C = [m x n]).
3343
+ /// \tparam [in] M The rows of A, C & D matrices
3344
+ /// \tparam [in] N The columns of B, C, D matrices
3345
+ /// \tparam [in] K The columns & rows of A & B matrices respectively
3346
+ /// \tparam [in] ABType The type of the input matrix (A & B) fragment
3347
+ /// \tparam [in] CDType The type of the output matrix (C & D) fragment
3348
+ /// \param [out] d_mat_frag The fragment of the output matrix D to store the
3349
+ /// result of A * B + C
3350
+ /// \param [in] a_mat_frag The fragment of the input matrix A to be
3351
+ /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
3352
+ /// the input matrix B to be multiplied with A matrix fragment \param [in]
3353
+ /// c_mat_frag The fragment of the input matrix C to be added with the
3354
+ /// result of A * B fragments
3355
+ template <int M, int N, int K, typename ABType, typename CDType>
3356
+ void mma(
3357
+ volatile void** d_mat_frag,
3358
+ void* a_mat_frag,
3359
+ void* b_mat_frag,
3360
+ void* c_mat_frag) {
3361
+ auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
3362
+ auto a =
3363
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
3364
+ auto b =
3365
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
3366
+ auto c = reinterpret_cast<CDType*>(c_mat_frag);
3367
+
3368
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3369
+ int lane = sg.get_local_linear_id();
3370
+
3371
+ static_assert(
3372
+ (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
3373
+ (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
3374
+ (M == 16 && N == 8 && K == 32),
3375
+ "Unsupported MMA shape!");
3376
+
3377
+ short row_load_offset = 4 * (lane >> 2);
3378
+ short col_load_offset = 8 * (lane % 4);
3379
+
3380
+ if constexpr (M == 8 && N == 8 && K == 4) {
3381
+ if constexpr (std::is_floating_point_v<CDType>) {
3382
+ col_load_offset = row_load_offset % 16;
3383
+
3384
+ // Init D matrix with fragments of C matrix
3385
+ *d[0] = c[0];
3386
+ *d[1] = c[1];
3387
+ *d[2] = c[2];
3388
+ *d[3] = c[3];
3389
+ *d[4] = c[4];
3390
+ *d[5] = c[5];
3391
+ *d[6] = c[6];
3392
+ *d[7] = c[7];
3393
+
3394
+ // Calculate the row and col offset indices to iterate through the row
3395
+ // & col fragments of A & B matrices
3396
+ int r_ind = (lane % 2) ? 1 : 0;
3397
+ int c_ind = ((lane % 4) / 2) ? 2 : 0;
3398
+
3399
+ // Each sub-group is responsible for computing a fragment size of 8*8
3400
+ // elements of matrix D for each of 4 MMA computations.
3401
+ // Each work item computes 8 elements of matrix D by gathering
3402
+ // their corresponding col & row matrix fragments of length k (4)
3403
+ // from A & B matrices respectively using below mapping logic:
3404
+ // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
3405
+ // col0 = (lane % 4)
3406
+ // As each row & col fragment of A & B matrices is distributed across
3407
+ // 4 work items, each iteration of below loop loads a partial fragment
3408
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3409
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3410
+
3411
+ for (int i = 0; i < 4; i++) {
3412
+ // Load partial fragment from col0 of matrix A ({a0, a1})
3413
+ recv_a[0] =
3414
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3415
+ // Load partial fragment from col0 of matrix A ({a2, a3})
3416
+ recv_a[1] =
3417
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3418
+
3419
+ // Load partial fragment from row0 of matrix B ({b0, b1})
3420
+ recv_b[0] =
3421
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3422
+ // Load partial fragment from row0 of matrix B ({b2, b3})
3423
+ recv_b[1] =
3424
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3425
+
3426
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3427
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3428
+
3429
+ // Each work item calculates a partial product of A & B matrix
3430
+ // fragments and adds it to the corresponding D matrix fragment (for
3431
+ // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
3432
+ // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
3433
+ // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
3434
+ // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
3435
+ // d3 += col1{ a3 } * row0{ b3 }
3436
+ *d[0] +=
3437
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3438
+ *d[1] += static_cast<float>(ra[r_ind]) *
3439
+ static_cast<float>(rb[c_ind + 1]);
3440
+ *d[2] += static_cast<float>(ra[r_ind + 2]) *
3441
+ static_cast<float>(rb[c_ind]);
3442
+ *d[3] += static_cast<float>(ra[r_ind + 2]) *
3443
+ static_cast<float>(rb[c_ind + 1]);
3444
+
3445
+ // Load partial fragment from row1 of matrix B ({b0, b1})
3446
+ recv_b[0] =
3447
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
3448
+ // Load partial fragment from row1 of matrix B ({b2, b3})
3449
+ recv_b[1] =
3450
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
3451
+
3452
+ // (for even work item indices)
3453
+ // d0 += col0{ a0 } * row1{ b0 }
3454
+ // d1 += col0{ a0 } * row1{ b1 }
3455
+ // d2 += col1{ a2 } * row1{ b0 }
3456
+ // d3 += col1{ a2 } * row1{ b1 }
3457
+ // (for odd work item indices)
3458
+ // d0 += col0{ a1 } * row1{ b2 }
3459
+ // d1 += col0{ a1 } * row1{ b3 }
3460
+ // d2 += col1{ a3 } * row1{ b2 }
3461
+ // d3 += col1{ a3 } * row1{ b3 }
3462
+ *d[4] +=
3463
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3464
+ *d[5] += static_cast<float>(ra[r_ind]) *
3465
+ static_cast<float>(rb[c_ind + 1]);
3466
+ *d[6] += static_cast<float>(ra[r_ind + 2]) *
3467
+ static_cast<float>(rb[c_ind]);
3468
+ *d[7] += static_cast<float>(ra[r_ind + 2]) *
3469
+ static_cast<float>(rb[c_ind + 1]);
3470
+ }
3471
+ }
3472
+ } else if constexpr (M == 8 && N == 8 && K == 16) {
3473
+ if constexpr (std::is_integral_v<ABType>) {
3474
+ // Init D matrix with fragments of C matrix
3475
+ *d[0] = c[0];
3476
+ *d[1] = c[1];
3477
+
3478
+ // Each sub-group is responsible for computing a fragment size of 16*8
3479
+ // elements of matrix D.
3480
+ // Each work item computes 2 elements of matrix D by gathering
3481
+ // their corresponding row & col matrix fragments of length k (16)
3482
+ // from A & B matrices respectively using below mapping logic:
3483
+ // row0 = ((lane % 4) * 4) + i
3484
+ // col0 = (lane >> 2)
3485
+ // As each row & col fragment of A & B matrices is distributed across
3486
+ // 4 work items, each iteration of below loop loads a partial fragment
3487
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3488
+ for (int i = 0; i < 4; i++) {
3489
+ typename MMAType<ABType>::PackType recv_a, recv_b[2];
3490
+
3491
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3492
+ recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3493
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3494
+ recv_b[0] =
3495
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3496
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3497
+ recv_b[1] =
3498
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3499
+
3500
+ auto a = reinterpret_cast<ABType*>(&recv_a);
3501
+ auto b = reinterpret_cast<ABType*>(recv_b);
3502
+
3503
+ // Each work item calculates a partial product of A & B matrix
3504
+ // fragments and adds it to the corresponding D matrix fragment d0
3505
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3506
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
3507
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
3508
+ // col1{ b0, b1, b2, b3 }
3509
+ for (int j = 0; j < 4; j++) {
3510
+ *d[0] += a[j] * b[j];
3511
+ *d[1] += a[j] * b[j + 4];
3512
+ }
3513
+ }
3514
+ }
3515
+ } else if constexpr (M == 16 && N == 8 && K == 8) {
3516
+ if constexpr (std::is_floating_point_v<CDType>) {
3517
+ // Init D matrix fragment with C matrix fragment
3518
+ *d[0] = c[0];
3519
+ *d[1] = c[1];
3520
+ *d[2] = c[2];
3521
+ *d[3] = c[3];
3522
+
3523
+ // Each sub-group is responsible for computing a fragment size of 16*8
3524
+ // elements of matrix D.
3525
+ // Each work item computes 4 elements of matrix D by gathering
3526
+ // their corresponding row & col matrix fragments of length k (8)
3527
+ // from A & B matrices respectively using below mapping logic:
3528
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3529
+ // col0 = (lane % 4) * 2 + (i & 0x1)
3530
+ // As each row & col fragment of A & B matrices is distributed across
3531
+ // 4 work items, each iteration of below loop loads a partial fragment
3532
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3533
+ for (int i = 0; i < 4; i++) {
3534
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3535
+
3536
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3537
+ recv_a[0] =
3538
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3539
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3540
+ recv_a[1] =
3541
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3542
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3543
+ recv_b[0] =
3544
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3545
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3546
+ recv_b[1] =
3547
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3548
+
3549
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3550
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3551
+
3552
+ // Each work item calculates a partial product of A & B matrix
3553
+ // fragments and adds it to the corresponding D matrix fragment d0
3554
+ // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
3555
+ // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
3556
+ // } * col1{ b0, b1 }
3557
+ for (int j = 0; j < 2; j++) {
3558
+ *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
3559
+ *d[1] +=
3560
+ static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
3561
+ *d[2] +=
3562
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
3563
+ *d[3] +=
3564
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
3565
+ }
3566
+ }
3567
+ }
3568
+ } else if constexpr (M == 16 && N == 8 && K == 16) {
3569
+ if constexpr (std::is_floating_point_v<CDType>) {
3570
+ // Init D matrix fragment with C matrix fragment
3571
+ *d[0] = c[0];
3572
+ *d[1] = c[1];
3573
+ *d[2] = c[2];
3574
+ *d[3] = c[3];
3575
+
3576
+ // Each sub-group is responsible for computing a fragment size of 16*8
3577
+ // elements of matrix D.
3578
+ // Each work item computes 4 elements of matrix D by gathering
3579
+ // their corresponding row & col matrix fragments of length k (8)
3580
+ // from A & B matrices respectively using below mapping logic:
3581
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3582
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3583
+ // As each row & col fragment of A & B matrices is distributed across
3584
+ // 4 work items, each iteration of below loop loads a partial fragment
3585
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3586
+ for (int i = 0; i < 4; i++) {
3587
+ typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
3588
+
3589
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3590
+ recv_a[0] =
3591
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3592
+ // Load partial fragment from row0 of matrix A ({a2, a3})
3593
+ recv_a[1] =
3594
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3595
+ // Load partial fragment from row1 of matrix A ({a0, a1})
3596
+ recv_a[2] =
3597
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3598
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3599
+ recv_a[3] =
3600
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3601
+
3602
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3603
+ recv_b[0] =
3604
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3605
+ // Load partial fragment from col0 of matrix B ({b2, b3})
3606
+ recv_b[1] =
3607
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3608
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3609
+ recv_b[2] =
3610
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
3611
+ // Load partial fragment from col1 of matrix B ({b2, b3})
3612
+ recv_b[3] =
3613
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
3614
+
3615
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3616
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3617
+
3618
+ // Each work item calculates a partial product of A & B matrix
3619
+ // fragments and adds it to the corresponding D matrix fragment d0
3620
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3621
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
3622
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
3623
+ // col1{ b0, b1, b2, b3 }
3624
+ for (int j = 0; j < 4; j++) {
3625
+ *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
3626
+ *d[1] +=
3627
+ static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
3628
+ *d[2] +=
3629
+ static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
3630
+ *d[3] += static_cast<CDType>(ra[j + 4]) *
3631
+ static_cast<CDType>(rb[j + 4]);
3632
+ }
3633
+ }
3634
+ } else if constexpr (std::is_integral_v<ABType>) {
3635
+ // Init D matrix with fragments of C matrix
3636
+ *d[0] = c[0];
3637
+ *d[1] = c[1];
3638
+ *d[2] = c[2];
3639
+ *d[3] = c[3];
3640
+
3641
+ // Each sub-group is responsible for computing a fragment size of 16*8
3642
+ // elements of matrix D.
3643
+ // Each work item computes 4 elements of matrix D by gathering
3644
+ // their corresponding row & col matrix fragments of length k (8)
3645
+ // from A & B matrices respectively using below mapping logic:
3646
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3647
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3648
+ // As each row & col fragment of A & B matrices is distributed across
3649
+ // 4 work items, each iteration of below loop loads a partial fragment
3650
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3651
+ for (int i = 0; i < 4; i++) {
3652
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3653
+
3654
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3655
+ recv_a[0] =
3656
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3657
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3658
+ recv_a[1] =
3659
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3660
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3661
+ recv_b[0] =
3662
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3663
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3664
+ recv_b[1] =
3665
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3666
+
3667
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3668
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3669
+
3670
+ // Each work item calculates a partial product of A & B matrix
3671
+ // fragments and adds it to the corresponding D matrix fragment d0
3672
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3673
+ // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
3674
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3675
+ // col1{ b4, b5, b6, b7 }
3676
+ for (int i = 0; i < 4; i++) {
3677
+ *d[0] += ra[i] * rb[i];
3678
+ *d[1] += ra[i] * rb[i + 4];
3679
+ *d[2] += ra[i + 4] * rb[i];
3680
+ *d[3] += ra[i + 4] * rb[i + 4];
3681
+ }
3682
+ }
3683
+ }
3684
+ } else if constexpr (M == 16 && N == 8 && K == 32) {
3685
+ if constexpr (std::is_integral_v<ABType>) {
3686
+ // Init D matrix with fragments of C matrix
3687
+ *d[0] = c[0];
3688
+ *d[1] = c[1];
3689
+ *d[2] = c[2];
3690
+ *d[3] = c[3];
3691
+
3692
+ // Each sub-group is responsible for computing a fragment size of 16*8
3693
+ // elements of matrix D.
3694
+ // Each work item computes 4 elements of matrix D by gathering
3695
+ // their corresponding row & col matrix fragments of length k (32)
3696
+ // from A & B matrices respectively using below mapping logic:
3697
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3698
+ // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
3699
+ // & 0x3) As each row & col fragment of A & B matrices is distributed
3700
+ // across 4 work items, each iteration of below loop loads a partial
3701
+ // fragment of matrix A (row) and matrix B (col) using the row & col
3702
+ // offsets.
3703
+ for (int i = 0; i < 4; i++) {
3704
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3705
+
3706
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3707
+ recv_a[0] =
3708
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3709
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3710
+ recv_a[1] =
3711
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3712
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3713
+ recv_b[0] =
3714
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3715
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3716
+ recv_b[1] =
3717
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3718
+
3719
+ auto a = reinterpret_cast<ABType*>(recv_a);
3720
+ auto b = reinterpret_cast<ABType*>(recv_b);
3721
+
3722
+ // Each work item calculates a partial product of A & B matrix
3723
+ // fragments and adds it to the corresponding D matrix fragment d0
3724
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3725
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
3726
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3727
+ // col1{ b0, b1, b2, b3 }
3728
+ for (int j = 0; j < 4; j++) {
3729
+ *d[0] += a[j] * b[j];
3730
+ *d[1] += a[j] * b[j + 4];
3731
+ *d[2] += a[j + 4] * b[j];
3732
+ *d[3] += a[j + 4] * b[j + 4];
3733
+ }
3734
+ }
3735
+
3736
+ for (int i = 0; i < 4; i++) {
3737
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3738
+
3739
+ // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
3740
+ recv_a[0] =
3741
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3742
+ // Load partial fragment from row1 of matrix A ({a12, a13, a14,
3743
+ // a15})
3744
+ recv_a[1] =
3745
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3746
+ // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
3747
+ recv_b[0] =
3748
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3749
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3750
+ recv_b[1] =
3751
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
3752
+
3753
+ auto a = reinterpret_cast<ABType*>(recv_a);
3754
+ auto b = reinterpret_cast<ABType*>(recv_b);
3755
+
3756
+ // Each work item calculates a partial product of A & B matrix
3757
+ // fragments and adds it to the corresponding D matrix fragment d0
3758
+ // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
3759
+ // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
3760
+ // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
3761
+ // a15 } * col1{ b4, b5, b6, b7 }
3762
+ for (int j = 0; j < 4; j++) {
3763
+ *d[0] += a[j] * b[j];
3764
+ *d[1] += a[j] * b[j + 4];
3765
+ *d[2] += a[j + 4] * b[j];
3766
+ *d[3] += a[j + 4] * b[j + 4];
3767
+ }
3768
+ }
3769
+ }
3770
+ }
3771
+ }
3028
3772
  } // COPY from DPCT head files
3029
3773
 
3030
3774
  #endif // GGML_SYCL_DPCT_HELPER_HPP