whispercpp 1.3.5 → 1.3.7

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -0,0 +1,1432 @@
1
#ifdef U32_DEQUANT_HELPERS
#define SRC0_TYPE u32

// Returns byte number `b` of the 32-bit word `v` as an unsigned value
// in 0..255. Byte 0 is the least significant byte.
fn byte_of(v: u32, b: u32) -> u32 {
    let bit_offset = b * 8u;
    return (v >> bit_offset) & 0xFFu;
}

// Returns byte number `b` of the 32-bit word `v` reinterpreted as a signed
// 8-bit integer, i.e. a value in -128..127.
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_byte = i32(byte_of(v, b));
    // Values with the top bit set (128..255) map to -128..-1.
    if (unsigned_byte >= 128) {
        return unsigned_byte - 256;
    }
    return unsigned_byte;
}
#endif
13
+
14
#ifdef VEC
#define VEC_SIZE 4u
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>

// Four-wide dot product of one packed weight vector with one packed input
// vector. The weight vector is first converted to the input vector type so
// that both operands of dot() have the same type; the result is widened
// (or passed through) to f32.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    let converted = SRC1_TYPE(src0_val);
    return f32(dot(converted, src1_val));
}
#endif
23
+
24
#ifdef SCALAR
#define VEC_SIZE 1u
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE

// Scalar counterpart of the vectorised inner_dot: both operands are
// converted to f32 before being multiplied.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    let lhs = f32(src0_val);
    let rhs = f32(src1_val);
    return lhs * rhs;
}
#endif
33
+
34
#ifdef MUL_ACC_FLOAT
// Per-thread partial matrix-vector product for non-quantised weights.
//
// thread_id          index of this thread; also its first column
// row_base           first output row handled by this call
// src0_batch_offset  element offset of the current weight matrix in src0
// src1_idx_base      element offset of the current input vector in src1
//
// Returns one partial sum per output row in row_base .. row_base +
// OUTPUTS_PER_WG - 1. Rows at or beyond params.m stay at zero. The caller is
// expected to combine the partial sums of all threads.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    // Both buffers are indexed in units of VEC_SIZE elements.
    let packed_k = params.k / VEC_SIZE;
    let x_origin = src1_idx_base / VEC_SIZE;

    // Columns are strided by WG_SIZE across threads. Each input value is
    // loaded once and reused for every output row of this call.
    for (var col = thread_id; col < packed_k; col += WG_SIZE) {
        let x_val = src1[x_origin + col];
        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let row_start = (src0_batch_offset + out_row * params.stride_01) / VEC_SIZE;
            partial[r] += inner_dot(src0[row_start + col], x_val);
        }
    }

    return partial;
}
#endif
57
+
58
#ifdef MUL_ACC_Q1_0
#define BLOCK_SIZE 128
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 16
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 1-bit quantised weights.
//
// As read by this kernel, a weight block is 18 bytes: an f16 scale at byte 0
// followed by 16 bytes of sign bits, one bit per element. A set bit decodes
// to +scale and a clear bit to -scale. Sixteen threads share a block and each
// of them consumes exactly one sign byte (eight elements).
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // Cache this thread's slice of the input vector; it is reused for
        // every output row below.
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * ELEMS_PER_THREAD;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            x_vals[e] = f32(src1[x_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));
            // Only the lowest byte of the loaded word belongs to this thread.
            let sign_bits = load_u32_at_src0(block_addr + 2u + slot) & 0xFFu;

            var dot_sum = 0.0;
            for (var bit = 0u; bit < 8u; bit++) {
                let bit_is_set = ((sign_bits >> bit) & 1u) != 0u;
                let weight = select(-scale, scale, bit_is_set);
                dot_sum += weight * x_vals[bit];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
94
+
95
#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 4-bit quantised weights with a
// single scale per block.
//
// As read by this kernel, a weight block is 18 bytes: an f16 scale `d` at
// byte 0 followed by 16 bytes of packed nibbles. The low nibble of packed
// byte j is element j and the high nibble is element j + 16; both decode as
// (nibble - 8) * d. Four threads share a block and each of them consumes one
// 32-bit word of packed nibbles (eight elements).
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    // The lane index must be derived from the same constant that drives the
    // block index and block stride below, as in the sibling kernels.
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // Each thread owns four packed bytes, so its low-nibble elements
        // start at 4 * lane and its high-nibble elements 16 further on.
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 4] = f32(src1[x_base + i + 16]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_f16_at_src0(block_byte_base));
                var row_sum = 0.0;

                let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
                    let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + 4u];
                }
                acc[row] += row_sum;
            }
        }
    }

    return acc;
}
#endif
136
+
137
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 20
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 4-bit quantised weights with a
// scale and an additive offset per block.
//
// As read by this kernel, a weight block is 20 bytes: an f16 scale at byte 0,
// an f16 offset at byte 2 and 16 bytes of packed nibbles from byte 4. The low
// nibble of packed byte j is element j and the high nibble is element j + 16;
// both decode as nibble * scale + offset. Four threads share a block and each
// of them consumes one 32-bit word of packed nibbles (eight elements).
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // x_vals[0..3] pair with the low nibbles, x_vals[4..7] with the high
        // nibbles (16 elements further into the block).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * 4;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            x_vals[e] = f32(src1[x_start + e]);
            x_vals[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));
            let offset = f32(load_f16_at_src0(block_addr + 2u));
            let nibble_word = load_u32_at_src0(block_addr + 4u + 4u * slot);

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let packed = get_byte(nibble_word, b);
                let w_lo = f32(packed & 0xFu) * scale + offset;
                let w_hi = f32((packed >> 4u) & 0xFu) * scale + offset;
                dot_sum += w_lo * x_vals[b];
                dot_sum += w_hi * x_vals[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
179
+
180
#ifdef MUL_ACC_Q5_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 22
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 5-bit quantised weights with a
// single scale per block.
//
// As read by this kernel, a weight block is 22 bytes: an f16 scale at byte 0,
// a 32-bit word of fifth bits at byte 2 and 16 bytes of packed nibbles from
// byte 6. For packed byte j the low nibble is element j, whose fifth bit is
// bit j of the word, and the high nibble is element j + 16, whose fifth bit
// is bit j + 16. Elements decode as (five_bit_value - 16) * scale. Four
// threads share a block and each of them consumes four packed bytes.
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // x_vals[0..3] pair with the low nibbles, x_vals[4..7] with the high
        // nibbles (16 elements further into the block).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * 4;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            x_vals[e] = f32(src1[x_start + e]);
            x_vals[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));
            let high_bits = load_u32_at_src0(block_addr + 2u);
            let nibble_word = load_u32_at_src0(block_addr + 6u + 4u * slot);
            let first_elem = slot * 4u;

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let packed = get_byte(nibble_word, b);
                let elem = first_elem + b;
                // Move the element's fifth bit into bit position 4.
                let fifth_lo = ((high_bits >> elem) << 4u) & 0x10u;
                let fifth_hi = (high_bits >> (elem + 12u)) & 0x10u;
                let w_lo = (f32((packed & 0xFu) | fifth_lo) - 16.0) * scale;
                let w_hi = (f32(((packed >> 4u) & 0xFu) | fifth_hi) - 16.0) * scale;
                dot_sum += w_lo * x_vals[b];
                dot_sum += w_hi * x_vals[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
225
+
226
#ifdef MUL_ACC_Q5_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 24
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 5-bit quantised weights with a
// scale and an additive offset per block.
//
// As read by this kernel, a weight block is 24 bytes: an f16 scale at byte 0,
// an f16 offset at byte 2, a 32-bit word of fifth bits at byte 4 and 16 bytes
// of packed nibbles from byte 8. For packed byte j the low nibble is element
// j, whose fifth bit is bit j of the word, and the high nibble is element
// j + 16, whose fifth bit is bit j + 16. Elements decode as
// five_bit_value * scale + offset. Four threads share a block and each of
// them consumes four packed bytes.
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // x_vals[0..3] pair with the low nibbles, x_vals[4..7] with the high
        // nibbles (16 elements further into the block).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * 4;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            x_vals[e] = f32(src1[x_start + e]);
            x_vals[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));
            let offset = f32(load_f16_at_src0(block_addr + 2u));
            let high_bits = load_u32_at_src0(block_addr + 4u);
            let nibble_word = load_u32_at_src0(block_addr + 8u + 4u * slot);
            let first_elem = slot * 4u;

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let packed = get_byte(nibble_word, b);
                let elem = first_elem + b;
                // Move the element's fifth bit into bit position 4.
                let fifth_lo = ((high_bits >> elem) << 4u) & 0x10u;
                let fifth_hi = (high_bits >> (elem + 12u)) & 0x10u;
                let w_lo = f32((packed & 0xFu) | fifth_lo) * scale + offset;
                let w_hi = f32(((packed >> 4u) & 0xFu) | fifth_hi) * scale + offset;
                dot_sum += w_lo * x_vals[b];
                dot_sum += w_hi * x_vals[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
272
+
273
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 34
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 8-bit quantised weights with a
// single scale per block.
//
// As read by this kernel, a weight block is 34 bytes: an f16 scale at byte 0
// followed by 32 signed bytes, one per element, each decoding as
// byte * scale. Four threads share a block and each of them consumes two
// consecutive 32-bit words (eight elements).
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // Cache this thread's contiguous slice of the input vector.
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * ELEMS_PER_THREAD;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            x_vals[e] = f32(src1[x_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));

            var dot_sum = 0.0;
            for (var word = 0u; word < ELEMS_PER_THREAD / 4u; word++) {
                let word_index = slot * 2u + word;
                let quants = load_u32_at_src0(block_addr + 2u + 4u * word_index);
                for (var b = 0u; b < 4u; b++) {
                    let weight = f32(get_byte_i32(quants, b)) * scale;
                    dot_sum += weight * x_vals[word * 4u + b];
                }
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
312
+
313
#ifdef MUL_ACC_Q8_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 36
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Per-thread partial matrix-vector product for 8-bit quantised weights whose
// block header holds two f16 values.
//
// As read by this kernel, a weight block is 36 bytes: an f16 at byte 0, an
// f16 at byte 2 and 32 signed bytes from byte 4. Each element is decoded as
// byte * first_f16 + second_f16. Four threads share a block and each of them
// consumes two consecutive 32-bit words (eight elements).
//
// NOTE(review): the second f16 is applied here as a per-element additive
// offset. Confirm against the host-side block definition that this field
// really is an offset and not a precomputed sum of the block.
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let slot = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        // Cache this thread's contiguous slice of the input vector.
        let x_start = src1_idx_base + blk * BLOCK_SIZE + slot * ELEMS_PER_THREAD;
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            x_vals[e] = f32(src1[x_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_addr));
            let offset = f32(load_f16_at_src0(block_addr + 2u));

            var dot_sum = 0.0;
            for (var word = 0u; word < ELEMS_PER_THREAD / 4u; word++) {
                let word_index = slot * 2u + word;
                let quants = load_u32_at_src0(block_addr + 4u + 4u * word_index);
                for (var b = 0u; b < 4u; b++) {
                    let weight = f32(get_byte_i32(quants, b)) * scale + offset;
                    dot_sum += weight * x_vals[word * 4u + b];
                }
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif
353
+
354
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 84
#define THREADS_PER_BLOCK 16
// Per-thread partial matrix-vector product for 2-bit quantised weights
// arranged in 256-element blocks.
//
// As read by this kernel, a weight block is 84 bytes: 16 scale bytes at byte
// 0, 64 bytes of packed 2-bit values from byte 16, an f16 scale at byte 80
// and an f16 minimum scale at byte 82. The low nibble of a scale byte
// multiplies the quantised values and the high nibble multiplies the minimum
// term. Sixteen threads share a block; each one handles four groups of four
// elements spaced 32 elements apart and reads one aligned 32-bit word of
// packed values.
//
// The 2-bit fields are masked in place rather than shifted down, so the
// partial sums are rescaled afterwards: 1/256 for fields taken from the
// upper byte of a 16-bit half, 1/4, 1/16 and 1/64 for the second, third and
// fourth field of a byte, and 1/16 (folded into dmin) for the unshifted high
// nibble of the scale bytes.
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    // Decompose the lane: iq picks the 128-element half of the block, ir the
    // 8-element strip inside it, phase the 4-element half of that strip.
    let lane = tid / 2u;
    let phase = tid % 2u;
    let iq = lane / 4u;
    let ir = lane % 4u;
    let is = ir / 2u;

    let y_offset = 128u * iq + 8u * ir + 4u * phase;
    let sc0_byte = 8u * iq + is;
    let sc2_byte = 8u * iq + is + 2u;
    let sc4_byte = 8u * iq + is + 4u;
    let sc6_byte = 8u * iq + is + 6u;
    let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase;

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 4u; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
            x_block[i + 8u] = f32(src1[x_base + 64u + i]);
            x_block[i + 12u] = f32(src1[x_base + 96u + i]);
        }

        // The per-group input sums depend only on x_block, so they are
        // computed once per block instead of once per output row.
        let sumy = vec4<f32>(
            x_block[0] + x_block[1] + x_block[2] + x_block[3],
            x_block[4] + x_block[5] + x_block[6] + x_block[7],
            x_block[8] + x_block[9] + x_block[10] + x_block[11],
            x_block[12] + x_block[13] + x_block[14] + x_block[15]);

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                let dall = f32(load_f16_at_src0(block_byte_base + 80u));
                let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0);

                // The aligned loader is combined with the low two address
                // bits to pick the wanted scale byte out of the loaded word.
                let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u);
                let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u);
                let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u);
                let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u);

                let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte);
                let qs0 = q_u32 & 0xFFFFu;
                let qs1 = q_u32 >> 16u;

                var acc1 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                var acc2 = vec4<f32>(0.0, 0.0, 0.0, 0.0);

                acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u);
                acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u);
                acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu);
                acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u);
                acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u);
                acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u);
                acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u);
                acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u);

                acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) +
                                    (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 +
                                    (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 +
                                    (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0)
                          - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) +
                                    sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u));
            }
        }
    }

    return acc;
}
#endif
438
+
439
#ifdef MUL_ACC_Q3_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
// Per-thread partial matrix-vector product for 3-bit quantised weights
// arranged in 256-element blocks.
//
// As read by this kernel, a weight block is 110 bytes: 32 bytes of high-bit
// masks at byte 0, 64 bytes of packed 2-bit values from byte 32, 12 bytes of
// packed 6-bit scales from byte 96 and an f16 scale at byte 108. Sixteen
// threads share a block; each one handles two runs of eight elements that
// lie 32 elements apart.
//
// The 2-bit fields are masked in place rather than shifted down, so the sums
// are rescaled afterwards (1/256 for fields taken from the upper byte of a
// 16-bit half, 0.25 for the second field pair, and a final division by
// 2^shift). Whenever an element's high-bit mask bit is clear, v1 or v2 times
// the input value is subtracted.
//
// Returns one partial sum per output row starting at row_base; rows at or
// beyond params.m stay at zero.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let first_block = thread_id / THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;

    let lane = tid / 2u;
    let phase = tid % 2u;
    let ip = lane / 4u;                 // 0 or 1
    let il = 2u * ((lane % 4u) / 2u);   // 0 or 2
    let ir = lane % 2u;
    let l0 = 8u * ir;

    // Byte offsets inside the block and element offset inside the input.
    let strip = l0 + 16u * phase;
    let q_byte = 32u + 32u * ip + strip;
    let h_byte = strip;
    let y_offset = 128u * ip + 32u * il + strip;

    let s_shift1 = 4u * ip;
    let s_shift2 = s_shift1 + il;

    let first_pair = il == 0u;
    let v1 = select(64.0, 4.0, first_pair);
    let v2 = 4.0 * v1;
    let shift = 2u * il;

    // In-place masks for the two 2-bit fields used by this lane, in the
    // lower and the upper byte of a 16-bit half.
    let qm0 = select(0x0030u, 0x0003u, first_pair);
    let qm1 = select(0x3000u, 0x0300u, first_pair);
    let qm2 = select(0x00C0u, 0x000Cu, first_pair);
    let qm3 = select(0xC000u, 0x0C00u, first_pair);

    // High-bit masks: mm_idx is always in 0..3 here and each step moves all
    // four masks up by two bit positions.
    let mm_idx = 2u * ip + il / 2u;
    let h_bit = 2u * mm_idx;
    let hm0 = 0x0001u << h_bit;
    let hm1 = 0x0100u << h_bit;
    let hm2 = 0x0002u << h_bit;
    let hm3 = 0x0200u << h_bit;

    let block_count = params.k / BLOCK_SIZE;

    for (var blk = first_block; blk < block_count; blk += block_step) {
        let x_start = src1_idx_base + blk * BLOCK_SIZE + y_offset;
        var x_vals: array<f32, 16>;
        for (var e = 0u; e < 8u; e++) {
            x_vals[e] = f32(src1[x_start + e]);
            x_vals[e + 8u] = f32(src1[x_start + 32u + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_addr = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            let d = f32(load_f16_at_src0(block_addr + 108u));

            // Rebuild four 6-bit scales: low four bits come from the first
            // eight scale bytes, the two upper bits from the last four.
            let scales_addr = block_addr + 96u;
            let a_il0 = load_u16_at_src0(scales_addr + il * 2u);
            let a_il1 = load_u16_at_src0(scales_addr + (il + 1u) * 2u);
            let a_4 = load_u16_at_src0(scales_addr + 8u);
            let a_5 = load_u16_at_src0(scales_addr + 10u);

            let upper_src = a_4 | (a_5 << 16u);
            let upper_bits = ((upper_src >> s_shift2) << 4u) & 0x30303030u;
            let lower_src = a_il0 | (a_il1 << 16u);
            let scales32 = ((lower_src >> s_shift1) & 0x0F0F0F0Fu) | upper_bits;

            // Stored scales carry a bias of 32.
            let scale0 = f32(i32(byte_of(scales32, phase)) - 32);
            let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32);

            let q_u32_0 = load_u32_at_src0(block_addr + q_byte);
            let q_u32_1 = load_u32_at_src0(block_addr + q_byte + 4u);
            let h_u32_0 = load_u32_at_src0(block_addr + h_byte);
            let h_u32_1 = load_u32_at_src0(block_addr + h_byte + 4u);

            var s1 = 0.0;
            var s2 = 0.0;
            var s3 = 0.0;
            var s4 = 0.0;
            var s5 = 0.0;
            var s6 = 0.0;

            for (var l = 0u; l < 8u; l += 2u) {
                // Pick the 16-bit half that holds elements l and l + 1.
                let second_word = l >= 4u;
                let upper_half = (l & 2u) != 0u;
                let q_word = select(q_u32_0, q_u32_1, second_word);
                let qs = select(q_word & 0xFFFFu, q_word >> 16u, upper_half);
                let h_word = select(h_u32_0, h_u32_1, second_word);
                let hv = select(h_word & 0xFFFFu, h_word >> 16u, upper_half);

                s1 += x_vals[l] * f32(qs & qm0);
                s2 += x_vals[l + 1u] * f32(qs & qm1);
                s3 += select(0.0, x_vals[l], (hv & hm0) == 0u) +
                      select(0.0, x_vals[l + 1u], (hv & hm1) == 0u);
                s4 += x_vals[l + 8u] * f32(qs & qm2);
                s5 += x_vals[l + 9u] * f32(qs & qm3);
                s6 += select(0.0, x_vals[l + 8u], (hv & hm2) == 0u) +
                      select(0.0, x_vals[l + 9u], (hv & hm3) == 0u);
            }

            let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1);
            let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2);
            partial[r] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift);
        }
    }

    return partial;
}
#endif
548
+
549
+ #ifdef MUL_ACC_Q4_K
550
+ #define BLOCK_SIZE 256
551
+ #define BLOCK_SIZE_BYTES 144
552
+ #define THREADS_PER_BLOCK 16
553
// MUL_ACC_Q4_K variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Threads are grouped
// THREADS_PER_BLOCK (16) at a time; each group walks the 256-element /
// 144-byte src0 blocks of a row, stepping by WG_SIZE / THREADS_PER_BLOCK,
// and each thread covers 16 src1 values of the block. Partials are only
// accumulated per thread here; nothing is combined across threads.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
554
+ var acc: array<f32, OUTPUTS_PER_WG>;
555
+
556
// tid: lane inside the 16-thread group; block_group: which group this thread is in.
+ let tid = thread_id % THREADS_PER_BLOCK;
557
+ let block_group = thread_id / THREADS_PER_BLOCK;
558
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
559
+
560
// Split the lane index into (im, in, ir); l0 is a multiple of 4 in 0..28.
+ let il = tid / 4u;
561
+ let ir = tid % 4u;
562
+ let im = il / 2u;
563
+ let in = il % 2u;
564
+ let l0 = 4u * (2u * ir + in);
565
+
566
// y_offset indexes src1 elements, q_offset indexes quant bytes (relative to byte 16).
+ let y_offset = 64u * im + l0;
567
+ let q_offset = 32u * im + l0;
// Byte offsets of three 16-bit scale words; the scale area starts at byte 4,
// right after the two f16 values read at bytes 0 and 2.
568
+ let sc0_byte = 4u + im * 2u;
569
+ let sc2_byte = 4u + (im + 2u) * 2u;
570
+ let sc4_byte = 4u + (im + 4u) * 2u;
571
+
572
+ let num_blocks = params.k / BLOCK_SIZE;
573
+
574
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
575
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
576
// Cache four runs of 4 src1 values (offsets 0, 32, 128, 160 from x_base)
// once per block so they are reused for every output row.
+ var x_block: array<f32, 16>;
577
+ for (var i = 0u; i < 4u; i++) {
578
+ x_block[i] = f32(src1[x_base + i]);
579
+ x_block[i + 4u] = f32(src1[x_base + 32u + i]);
580
+ x_block[i + 8u] = f32(src1[x_base + 128u + i]);
581
+ x_block[i + 12u] = f32(src1[x_base + 160u + i]);
582
+ }
583
+
584
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
585
+ let output_row = row_base + row;
586
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
587
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
588
+
589
+ let d = f32(load_f16_at_src0(block_byte_base + 0u));
590
+ let dmin = f32(load_f16_at_src0(block_byte_base + 2u));
591
+
592
// Each scale word is taken as one 16-bit half of a 32-bit load, chosen by
// bit 1 of the byte offset.
// NOTE(review): presumably load_u32_at_src0_aligned rounds the address down
// to a 4-byte boundary — confirm against the helper.
+ let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
593
+ let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
594
+ let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
595
+ let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
596
+ let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
597
+ let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
598
+
599
// Build pairs of 6-bit fields: sc16_0/sc16_1 are the low 6 bits of each byte
// of sc0/sc2; sc16_2/sc16_3 join a nibble of sc4 with the top two bits of
// sc0/sc2 moved down into bits 4..5.
+ let sc16_0 = sc0 & 0x3F3Fu;
600
+ let sc16_1 = sc2 & 0x3F3Fu;
601
+ let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
602
+ let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
603
+
604
+ let scale0 = f32(sc16_0 & 0xFFu);
605
+ let scale1 = f32((sc16_0 >> 8u) & 0xFFu);
606
+ let min0 = f32(sc16_1 & 0xFFu);
607
+ let min1 = f32((sc16_1 >> 8u) & 0xFFu);
608
+ let scale2 = f32(sc16_2 & 0xFFu);
609
+ let scale3 = f32((sc16_2 >> 8u) & 0xFFu);
610
+ let min2 = f32(sc16_3 & 0xFFu);
611
+ let min3 = f32((sc16_3 >> 8u) & 0xFFu);
612
+
613
// Two 4-byte quant words, 64 bytes apart (bytes 16 + q_offset and 80 + q_offset).
+ let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset);
614
+ let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset);
615
+
616
+ var dot = vec4<f32>(0.0, 0.0, 0.0, 0.0);
617
+ var sumx = vec4<f32>(0.0, 0.0, 0.0, 0.0);
618
// Low/high nibble of each byte feeds a separate lane of dot; sumx keeps the
// plain sums of the matching src1 values for the dmin term.
+ for (var i = 0u; i < 4u; i++) {
619
+ let q1b = byte_of(q1_u32, i);
620
+ let q2b = byte_of(q2_u32, i);
621
+ dot[0] += x_block[i] * f32(q1b & 0x0Fu);
622
+ dot[1] += x_block[i + 4u] * f32(q1b >> 4u);
623
+ dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu);
624
+ dot[3] += x_block[i + 12u] * f32(q2b >> 4u);
625
+ sumx[0] += x_block[i];
626
+ sumx[1] += x_block[i + 4u];
627
+ sumx[2] += x_block[i + 8u];
628
+ sumx[3] += x_block[i + 12u];
629
+ }
630
+
631
// Contribution = d * (scaled dot products) - dmin * (min-weighted src1 sums).
+ acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3)
632
+ - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3);
633
+ }
634
+ }
635
+ }
636
+
637
+ return acc;
638
+ }
639
+ #endif
640
+
641
+ #ifdef MUL_ACC_Q5_K
642
+ #define BLOCK_SIZE 256
643
+ #define BLOCK_SIZE_BYTES 176
644
+ #define THREADS_PER_BLOCK 16
645
// MUL_ACC_Q5_K variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 176-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK; each thread covers 16 src1 values per block.
// Each quant value is a 4-bit nibble plus one extra bit taken from a
// separate byte (qhb) and placed at bit 4.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
646
+ var acc: array<f32, OUTPUTS_PER_WG>;
647
+
648
// tid: lane inside the 16-thread group; block_group: which group this thread is in.
+ let tid = thread_id % THREADS_PER_BLOCK;
649
+ let block_group = thread_id / THREADS_PER_BLOCK;
650
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
651
+
652
+ let il = tid / 4u;
653
+ let ir = tid % 4u;
654
+ let im = il / 2u;
655
+ let in = il % 2u;
656
+ let l0 = 4u * (2u * ir + in);
657
+
658
// Block-relative byte offsets: nibble bytes start at 48, high-bit bytes at 16
// (8u * ir + 4u * in equals l0), 16-bit scale words start at byte 4.
+ let y_offset = 64u * im + l0;
659
+ let q_offset = 48u + 32u * im + l0;
660
+ let qh_offset = 16u + 8u * ir + 4u * in;
661
+ let sc0_byte = 4u + im * 2u;
662
+ let sc2_byte = 4u + (im + 2u) * 2u;
663
+ let sc4_byte = 4u + (im + 4u) * 2u;
664
+
665
// Single-bit masks into a high-bit byte: hm1/hm2 are bits 2*im and 2*im+1,
// hm3/hm4 are the same two bits shifted up by 4.
+ let hm1 = 1u << (2u * im);
666
+ let hm2 = hm1 << 1u;
667
+ let hm3 = hm1 << 4u;
668
+ let hm4 = hm2 << 4u;
669
+
670
+ let num_blocks = params.k / BLOCK_SIZE;
671
+
672
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
673
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
674
// Cache four runs of 4 src1 values (offsets 0, 32, 128, 160 from x_base)
// once per block so they are reused for every output row.
+ var x_block: array<f32, 16>;
675
+ for (var i = 0u; i < 4u; i++) {
676
+ x_block[i] = f32(src1[x_base + i]);
677
+ x_block[i + 4u] = f32(src1[x_base + 32u + i]);
678
+ x_block[i + 8u] = f32(src1[x_base + 128u + i]);
679
+ x_block[i + 12u] = f32(src1[x_base + 160u + i]);
680
+ }
681
+
682
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
683
+ let output_row = row_base + row;
684
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
685
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
686
+
687
+ let d = f32(load_f16_at_src0(block_byte_base + 0u));
688
+ let dmin = f32(load_f16_at_src0(block_byte_base + 2u));
689
+
690
// Each scale word is one 16-bit half of a 32-bit load, chosen by bit 1 of
// the byte offset.
// NOTE(review): presumably load_u32_at_src0_aligned rounds the address down
// to a 4-byte boundary — confirm against the helper.
+ let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
691
+ let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
692
+ let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
693
+ let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
694
+ let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
695
+ let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
696
+
697
// Pairs of 6-bit fields: sc16_0/sc16_1 are the low 6 bits of each byte of
// sc0/sc2; sc16_2/sc16_3 join a nibble of sc4 with the top two bits of
// sc0/sc2 moved down into bits 4..5.
+ let sc16_0 = sc0 & 0x3F3Fu;
698
+ let sc16_1 = sc2 & 0x3F3Fu;
699
+ let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
700
+ let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
701
+
702
// f* multiply the quant dot products, m* multiply the plain src1 sums below.
+ let f0 = f32(sc16_0 & 0xFFu);
703
+ let f1 = f32((sc16_0 >> 8u) & 0xFFu);
704
+ let m0 = f32(sc16_1 & 0xFFu);
705
+ let m1 = f32((sc16_1 >> 8u) & 0xFFu);
706
+ let f4 = f32(sc16_2 & 0xFFu);
707
+ let f5 = f32((sc16_2 >> 8u) & 0xFFu);
708
+ let m4 = f32(sc16_3 & 0xFFu);
709
+ let m5 = f32((sc16_3 >> 8u) & 0xFFu);
710
+
711
// Two nibble words 64 bytes apart, plus one word of high-bit bytes.
+ let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset);
712
+ let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u);
713
+ let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset);
714
+
715
+ var vals = vec4<f32>(0.0, 0.0, 0.0, 0.0);
716
+ var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
717
+ for (var i = 0u; i < 4u; i++) {
718
+ let q1b = byte_of(q1_u32, i);
719
+ let q2b = byte_of(q2_u32, i);
720
+ let qhb = byte_of(qh_u32, i);
721
+
722
+ let yl0 = x_block[i];
723
+ let yl8 = x_block[i + 4u];
724
+ let yh0 = x_block[i + 8u];
725
+ let yh8 = x_block[i + 12u];
726
+
727
+ sumy[0] += yl0;
728
+ sumy[1] += yl8;
729
+ sumy[2] += yh0;
730
+ sumy[3] += yh8;
731
+
732
// 5-bit value = nibble, with 0x10 OR-ed in when the matching qhb bit is set.
+ let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u));
733
+ let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u));
734
+ let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u));
735
+ let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u));
736
+
737
+ vals[0] += yl0 * q0;
738
+ vals[1] += yl8 * q1;
739
+ vals[2] += yh0 * q2;
740
+ vals[3] += yh8 * q3;
741
+ }
742
+
743
// Contribution = d * (scaled dot products) - dmin * (min-weighted src1 sums).
+ acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3])
744
+ - dmin * (sumy[0] * m0 + sumy[1] * m1 +
745
+ sumy[2] * m4 + sumy[3] * m5);
746
+ }
747
+ }
748
+ }
749
+
750
+ return acc;
751
+ }
752
+ #endif
753
+
754
+ #ifdef MUL_ACC_Q6_K
755
+ #define BLOCK_SIZE 256
756
+ #define BLOCK_SIZE_BYTES 210
757
+ #define THREADS_PER_BLOCK 16
758
// MUL_ACC_Q6_K variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 210-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK; each thread covers 16 src1 values per block.
// Each quant value is 6 bits (a nibble plus two bits from a separate byte)
// with 32 subtracted, weighted by a per-group byte scale and the f16 at
// byte 208.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
759
+ var acc: array<f32, OUTPUTS_PER_WG>;
760
+
761
// tid: lane inside the 16-thread group; block_group: which group this thread is in.
+ let tid = thread_id % THREADS_PER_BLOCK;
762
+ let block_group = thread_id / THREADS_PER_BLOCK;
763
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
764
+
765
// ip selects a half of the block (0 or 1); l0 is a multiple of 4 in 0..28;
// is is the index of the first scale byte used by this thread.
+ let ip = tid / 8u;
766
+ let il = tid % 8u;
767
+ let l0 = 4u * il;
768
+ let is = 8u * ip + l0 / 16u;
769
+
770
+ let y_offset = 128u * ip + l0;
771
+ let q_offset_l = 64u * ip + l0;
772
+ let q_offset_h = 32u * ip + l0;
773
+
774
+ let num_blocks = params.k / BLOCK_SIZE;
775
// Scale bytes start at byte 192; split the index into a 4-byte-multiple word
// offset and a byte position inside that word.
+ let sc_base_byte = 192u + (is & ~3u);
776
+ let sc_byte_pos = is & 3u;
777
+
778
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
779
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
780
// Cache four runs of 4 src1 values (offsets 0, 32, 64, 96 from x_base)
// once per block so they are reused for every output row.
+ var x_block: array<f32, 16>;
781
+ for (var l = 0u; l < 4u; l++) {
782
+ x_block[l] = f32(src1[x_base + l]);
783
+ x_block[l + 4u] = f32(src1[x_base + 32u + l]);
784
+ x_block[l + 8u] = f32(src1[x_base + 64u + l]);
785
+ x_block[l + 12u] = f32(src1[x_base + 96u + l]);
786
+ }
787
+
788
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
789
+ let output_row = row_base + row;
790
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
791
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
792
+
793
// Reads: f16 scale at byte 208, two nibble words 32 bytes apart, one word
// of high-bit bytes from the area starting at byte 128, two scale words.
+ let d = f32(load_f16_at_src0(block_byte_base + 208u));
794
+ let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l);
795
+ let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u);
796
+ let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h);
797
+ let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte);
798
+ let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u);
799
+
800
// Scale bytes at indices is, is+2, is+4, is+6.
// NOTE(review): sbyte_of presumably returns the byte sign-extended — confirm.
+ let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
801
+ let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
802
+ let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
803
+ let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
804
+
805
+ var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
806
+
807
+ for (var l = 0u; l < 4u; l++) {
808
+ let q1b = byte_of(ql1_u32, l);
809
+ let q2b = byte_of(ql2_u32, l);
810
+ let qhb = byte_of(qh_u32, l);
811
+
812
// Each 6-bit value = nibble | (one 2-bit pair of qhb moved to bits 4..5), minus 32.
+ let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
813
+ let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
814
+ let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
815
+ let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
816
+
817
+ sums[0] += x_block[l] * dq0;
818
+ sums[1] += x_block[l + 4u] * dq1;
819
+ sums[2] += x_block[l + 8u] * dq2;
820
+ sums[3] += x_block[l + 12u] * dq3;
821
+ }
822
+
823
+ acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
824
+ sums[2] * f32(sc4) + sums[3] * f32(sc6));
825
+ }
826
+ }
827
+ }
828
+
829
+ return acc;
830
+ }
831
+ #endif
832
+
833
+ #ifdef MUL_ACC_IQ1_S
834
+ #define BLOCK_SIZE 256
835
+ #define BLOCK_SIZE_BYTES 50
836
+ #define THREADS_PER_BLOCK 16
837
// MUL_ACC_IQ1_S variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 50-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups whose quant values come from the iq1_grid table.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
838
+ var acc: array<f32, OUTPUTS_PER_WG>;
839
+
840
+ let tid = thread_id % THREADS_PER_BLOCK;
841
+ let block_group = thread_id / THREADS_PER_BLOCK;
842
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
843
+
844
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
845
+ let half = tid % 2u;
846
+ let slot0 = half * 2u;
847
+ let y_offset = sub_blk * 32u + slot0 * 8u;
848
+
849
+ let num_blocks = params.k / BLOCK_SIZE;
850
+
851
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
852
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
853
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
854
+ for (var i = 0u; i < 16u; i++) {
855
+ x_block[i] = f32(src1[x_base + i]);
856
+ }
857
+
858
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
859
+ let output_row = row_base + row;
860
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
861
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
862
+
863
+ let d = f32(load_f16_at_src0(block_byte_base));
864
// 16-bit qh for this sub-block; only the low half of the 32-bit load is kept.
+ let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
865
// Bits 12..14 of qh give a 3-bit scale s, applied as (2*s + 1); bit 15
// selects the sign of IQ1_DELTA.
+ let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
866
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
867
+ let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
868
+
869
+ var row_sum = 0.0;
870
+ for (var ll = 0u; ll < 2u; ll++) {
871
+ let l = slot0 + ll;
872
+ let qs_byte = get_byte(qs_w, l);
873
// 11-bit grid index (8 bits from qs, 3 bits from qh), times 8 values per entry.
+ let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
874
// iq1_grid packs sixteen 2-bit values per u32 word.
+ let gw = iq1_grid[ig / 16u];
875
+ let bit_base = (ig % 16u) * 2u;
876
+ for (var j = 0u; j < 8u; j++) {
877
+ let g = (gw >> (bit_base + j * 2u)) & 3u;
878
// Values with bit 1 set are mapped to g - 4 (2 -> -2, 3 -> -1).
+ let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
879
+ row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
880
+ }
881
+ }
882
+ acc[row] += row_sum;
883
+ }
884
+ }
885
+ }
886
+
887
+ return acc;
888
+ }
889
+ #endif
890
+
891
+ #ifdef MUL_ACC_IQ1_M
892
+ #define BLOCK_SIZE 256
893
+ #define BLOCK_SIZE_BYTES 56
894
+ #define THREADS_PER_BLOCK 16
895
// MUL_ACC_IQ1_M variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 56-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups whose quant values come from the iq1_grid
// table. The block scale is rebuilt from the top nibbles of four 16-bit words.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
896
+ var acc: array<f32, OUTPUTS_PER_WG>;
897
+
898
+ let tid = thread_id % THREADS_PER_BLOCK;
899
+ let block_group = thread_id / THREADS_PER_BLOCK;
900
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
901
+
902
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
903
+ let half = tid % 2u;
904
+ let slot0 = half * 2u;
905
+ let y_offset = sub_blk * 32u + slot0 * 8u;
906
+
907
+ let num_blocks = params.k / BLOCK_SIZE;
908
+
909
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
910
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
911
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
912
+ for (var i = 0u; i < 16u; i++) {
913
+ x_block[i] = f32(src1[x_base + i]);
914
+ }
915
+
916
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
917
+ let output_row = row_base + row;
918
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
919
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
920
+
921
// Four 16-bit scale words occupy bytes 48..55 of the block.
+ let sc_lo = load_u32_at_src0(block_byte_base + 48u);
922
+ let sc_hi = load_u32_at_src0(block_byte_base + 52u);
923
+ let sc0 = sc_lo & 0xFFFFu;
924
+ let sc1 = (sc_lo >> 16u) & 0xFFFFu;
925
+ let sc2 = sc_hi & 0xFFFFu;
926
+ let sc3 = (sc_hi >> 16u) & 0xFFFFu;
927
// The top nibble of each scale word forms one nibble of a 16-bit pattern
// (sc0 lowest, sc3 highest), which is reinterpreted as an f16.
+ let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
928
+ let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
929
+
930
// Pick the scale word for this sub-block: 0,1 -> sc0; 2,3 -> sc1; 4,5 -> sc2; 6,7 -> sc3.
+ let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
931
+ select(sc0, sc1, sub_blk >= 2u),
932
+ sub_blk < 4u);
933
+
934
+ let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
935
// Two qh bytes per sub-block; only the low half of the 32-bit load is kept.
+ let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
936
+ let qh_lo = qh & 0xFFu;
937
+ let qh_hi = (qh >> 8u) & 0xFFu;
938
+
939
+ var row_sum = 0.0;
940
+ for (var ll = 0u; ll < 2u; ll++) {
941
+ let l = slot0 + ll;
942
// 3-bit sub-scale s inside the scale word, applied as (2*s + 1).
+ let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
943
+ let sub_scale = (sc_u16 >> bit_off) & 0x7u;
944
+ let dl = d * f32(2u * sub_scale + 1u);
945
+ let qh_byte = select(qh_lo, qh_hi, l >= 2u);
946
+ let ll2 = l % 2u;
947
// Per nibble of qh_byte: low 3 bits extend the grid index, bit 3 picks the delta sign.
+ let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
948
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
949
+ let ig = grid_idx * 8u;
950
// iq1_grid packs sixteen 2-bit values per u32 word.
+ let gw = iq1_grid[ig / 16u];
951
+ let bit_base = (ig % 16u) * 2u;
952
+ for (var j = 0u; j < 8u; j++) {
953
+ let g = (gw >> (bit_base + j * 2u)) & 3u;
954
// Values with bit 1 set are mapped to g - 4 (2 -> -2, 3 -> -1).
+ let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
955
+ row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
956
+ }
957
+ }
958
+ acc[row] += row_sum;
959
+ }
960
+ }
961
+ }
962
+
963
+ return acc;
964
+ }
965
+ #endif
966
+
967
+ #ifdef MUL_ACC_IQ2_XXS
968
+ #define BLOCK_SIZE 256
969
+ #define BLOCK_SIZE_BYTES 66
970
+ #define THREADS_PER_BLOCK 16
971
// MUL_ACC_IQ2_XXS variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 66-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups whose magnitudes come from iq2xxs_grid and
// whose signs come from ksigns_iq2xs.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
972
+ var acc: array<f32, OUTPUTS_PER_WG>;
973
+
974
+ let tid = thread_id % THREADS_PER_BLOCK;
975
+ let block_group = thread_id / THREADS_PER_BLOCK;
976
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
977
+
978
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
979
+ let half = tid % 2u;
980
+ let slot0 = half * 2u;
981
+ let y_offset = sub_blk * 32u + slot0 * 8u;
982
+
983
+ let num_blocks = params.k / BLOCK_SIZE;
984
+
985
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
986
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
987
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
988
+ for (var i = 0u; i < 16u; i++) {
989
+ x_block[i] = f32(src1[x_base + i]);
990
+ }
991
+
992
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
993
+ let output_row = row_base + row;
994
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
995
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
996
+ let d = f32(load_f16_at_src0(block_byte_base));
997
// Eight bytes per sub-block after the f16: aux_lo holds four 8-bit grid
// indices; aux_hi holds four 7-bit sign indices with a 4-bit scale on top.
+ let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
998
+ let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
999
+ let ls = aux_hi >> 28u;
1000
+ let db = d * (0.5 + f32(ls)) * 0.25;
1001
+
1002
+ var row_sum = 0.0;
1003
+ for (var ll = 0u; ll < 2u; ll++) {
1004
+ let l = slot0 + ll;
1005
+ let grid_idx = (aux_lo >> (8u * l)) & 0xFFu;
1006
+ let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu;
1007
// ksigns_iq2xs packs four sign bytes per u32 word.
+ let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
1008
// Each grid entry spans two u32 words: eight magnitude bytes.
+ let gw_lo = iq2xxs_grid[grid_idx * 2u];
1009
+ let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u];
1010
+ for (var j = 0u; j < 8u; j++) {
1011
+ let gw = select(gw_hi, gw_lo, j < 4u);
1012
+ let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
1013
// Bit j of the sign byte negates value j.
+ let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
1014
+ row_sum += db * b * s * x_block[ll * 8u + j];
1015
+ }
1016
+ }
1017
+ acc[row] += row_sum;
1018
+ }
1019
+ }
1020
+ }
1021
+
1022
+ return acc;
1023
+ }
1024
+ #endif
1025
+
1026
+ #ifdef MUL_ACC_IQ2_XS
1027
+ #define BLOCK_SIZE 256
1028
+ #define BLOCK_SIZE_BYTES 74
1029
+ #define THREADS_PER_BLOCK 16
1030
// MUL_ACC_IQ2_XS variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 74-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups whose magnitudes come from iq2xs_grid and whose
// signs come from ksigns_iq2xs.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1031
+ var acc: array<f32, OUTPUTS_PER_WG>;
1032
+
1033
+ let tid = thread_id % THREADS_PER_BLOCK;
1034
+ let block_group = thread_id / THREADS_PER_BLOCK;
1035
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
1036
+
1037
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
1038
+ let half = tid % 2u;
1039
+ let slot0 = half * 2u;
1040
+ let y_offset = sub_blk * 32u + slot0 * 8u;
1041
+
1042
+ let num_blocks = params.k / BLOCK_SIZE;
1043
+
1044
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
1045
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
1046
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
1047
+ for (var i = 0u; i < 16u; i++) {
1048
+ x_block[i] = f32(src1[x_base + i]);
1049
+ }
1050
+
1051
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1052
+ let output_row = row_base + row;
1053
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1054
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1055
+ let d = f32(load_f16_at_src0(block_byte_base));
1056
// Eight bytes of qs per sub-block (four 16-bit values) after the f16.
+ let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
1057
+ let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
1058
// One scale byte per sub-block, starting at byte 66.
+ let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
1059
+ let scales_byte = get_byte(scales_word, sub_blk % 4u);
1060
+
1061
+ var row_sum = 0.0;
1062
+ for (var ll = 0u; ll < 2u; ll++) {
1063
+ let l = slot0 + ll;
1064
// Select the 16-bit value for group l: low 9 bits are the grid index,
// the next 7 bits are the sign index.
+ let qs_word = select(qs_hi, qs_lo, l < 2u);
1065
+ let half2 = (l % 2u) * 16u;
1066
+ let qs_val = (qs_word >> half2) & 0xFFFFu;
1067
+ let grid_idx = qs_val & 0x1FFu;
1068
+ let signs_idx = (qs_val >> 9u) & 0x7Fu;
1069
// Low nibble of the scale byte for groups 0-1, high nibble for groups 2-3.
+ let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
1070
+ let db = d * (0.5 + f32(sub_scale)) * 0.25;
1071
// ksigns_iq2xs packs four sign bytes per u32 word.
+ let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
1072
// Each grid entry spans two u32 words: eight magnitude bytes.
+ let gw_lo = iq2xs_grid[grid_idx * 2u];
1073
+ let gw_hi = iq2xs_grid[grid_idx * 2u + 1u];
1074
+ for (var j = 0u; j < 8u; j++) {
1075
+ let gw = select(gw_hi, gw_lo, j < 4u);
1076
+ let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
1077
// Bit j of the sign byte negates value j.
+ let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
1078
+ row_sum += db * b * s * x_block[ll * 8u + j];
1079
+ }
1080
+ }
1081
+ acc[row] += row_sum;
1082
+ }
1083
+ }
1084
+ }
1085
+
1086
+ return acc;
1087
+ }
1088
+ #endif
1089
+
1090
+ #ifdef MUL_ACC_IQ2_S
1091
+ #define BLOCK_SIZE 256
1092
+ #define BLOCK_SIZE_BYTES 82
1093
+ #define THREADS_PER_BLOCK 16
1094
// MUL_ACC_IQ2_S variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 82-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups whose magnitudes come from iq2s_grid and whose
// signs are stored directly as one byte per group.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1095
+ var acc: array<f32, OUTPUTS_PER_WG>;
1096
+
1097
+ let tid = thread_id % THREADS_PER_BLOCK;
1098
+ let block_group = thread_id / THREADS_PER_BLOCK;
1099
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
1100
+
1101
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
1102
+ let half = tid % 2u;
1103
+ let slot0 = half * 2u;
1104
+ let y_offset = sub_blk * 32u + slot0 * 8u;
1105
+
1106
+ let num_blocks = params.k / BLOCK_SIZE;
1107
+
1108
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
1109
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
1110
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
1111
+ for (var i = 0u; i < 16u; i++) {
1112
+ x_block[i] = f32(src1[x_base + i]);
1113
+ }
1114
+
1115
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1116
+ let output_row = row_base + row;
1117
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1118
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1119
+ let d = f32(load_f16_at_src0(block_byte_base));
1120
// Per sub-block: four grid-index bytes (from byte 2), four sign bytes (from
// byte 34), one qh byte (from byte 66) and one scale byte (from byte 74).
+ let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
1121
+ let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u);
1122
+ let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
1123
+ let qh_byte = get_byte(qh_word, sub_blk % 4u);
1124
+ let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u);
1125
+ let scales_byte = get_byte(sc_word, sub_blk % 4u);
1126
+
1127
+ var row_sum = 0.0;
1128
+ for (var ll = 0u; ll < 2u; ll++) {
1129
+ let l = slot0 + ll;
1130
+ let qs_byte = get_byte(qs_w, l);
1131
+ let sign_byte = get_byte(sg_w, l);
1132
// 10-bit grid index: 8 bits from qs plus 2 bits of qh_byte for group l.
+ let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u);
1133
// Low nibble of the scale byte for groups 0-1, high nibble for groups 2-3.
+ let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
1134
+ let db = d * (0.5 + f32(sub_scale)) * 0.25;
1135
// Each grid entry spans two u32 words: eight magnitude bytes.
+ let gw_lo = iq2s_grid[grid_idx * 2u];
1136
+ let gw_hi = iq2s_grid[grid_idx * 2u + 1u];
1137
+ for (var j = 0u; j < 8u; j++) {
1138
+ let gw = select(gw_hi, gw_lo, j < 4u);
1139
+ let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
1140
// Bit j of the sign byte negates value j.
+ let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
1141
+ row_sum += db * b * s * x_block[ll * 8u + j];
1142
+ }
1143
+ }
1144
+ acc[row] += row_sum;
1145
+ }
1146
+ }
1147
+ }
1148
+
1149
+ return acc;
1150
+ }
1151
+ #endif
1152
+
1153
+ #ifdef MUL_ACC_IQ3_XXS
1154
+ #define BLOCK_SIZE 256
1155
+ #define BLOCK_SIZE_BYTES 98
1156
+ #define THREADS_PER_BLOCK 16
1157
// MUL_ACC_IQ3_XXS variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 98-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups, each built from two 4-value iq3xxs_grid
// entries, with signs from ksigns_iq2xs.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1158
+ var acc: array<f32, OUTPUTS_PER_WG>;
1159
+
1160
+ let tid = thread_id % THREADS_PER_BLOCK;
1161
+ let block_group = thread_id / THREADS_PER_BLOCK;
1162
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
1163
+
1164
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
1165
+ let half = tid % 2u;
1166
+ let slot0 = half * 2u;
1167
+ let y_offset = sub_blk * 32u + slot0 * 8u;
1168
+
1169
+ let num_blocks = params.k / BLOCK_SIZE;
1170
+
1171
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
1172
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
1173
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
1174
+ for (var i = 0u; i < 16u; i++) {
1175
+ x_block[i] = f32(src1[x_base + i]);
1176
+ }
1177
+
1178
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1179
+ let output_row = row_base + row;
1180
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1181
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1182
+ let d = f32(load_f16_at_src0(block_byte_base));
1183
// Eight grid-index bytes per sub-block (from byte 2) and one 32-bit aux
// word per sub-block (from byte 66): four 7-bit sign indices with a 4-bit
// scale in the top bits.
+ let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
1184
+ let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
1185
+ let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u);
1186
+ let ls = aux >> 28u;
1187
+ let db = d * (0.5 + f32(ls)) * 0.5;
1188
+
1189
+ var row_sum = 0.0;
1190
+ for (var ll = 0u; ll < 2u; ll++) {
1191
+ let l = slot0 + ll;
1192
// Group l uses index bytes 2*l and 2*l+1 of the sub-block's eight.
+ let qs_word = select(qs_hi, qs_lo, l < 2u);
1193
+ let byte_pos = (l % 2u) * 2u;
1194
+ let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
1195
+ let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
1196
+ let signs_idx = (aux >> (7u * l)) & 0x7Fu;
1197
// ksigns_iq2xs packs four sign bytes per u32 word.
+ let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
1198
// Each grid entry is one u32 holding four magnitude bytes.
+ let grid1 = iq3xxs_grid[grid_idx_0];
1199
+ let grid2 = iq3xxs_grid[grid_idx_1];
1200
+ for (var j = 0u; j < 4u; j++) {
1201
+ let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
1202
+ let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
1203
// Sign bits 0..3 apply to grid1's values, bits 4..7 to grid2's.
+ let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
1204
+ let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
1205
+ row_sum += db * b1 * s1 * x_block[ll * 8u + j];
1206
+ row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
1207
+ }
1208
+ }
1209
+ acc[row] += row_sum;
1210
+ }
1211
+ }
1212
+ }
1213
+
1214
+ return acc;
1215
+ }
1216
+ #endif
1217
+
1218
+ #ifdef MUL_ACC_IQ3_S
1219
+ #define BLOCK_SIZE 256
1220
+ #define BLOCK_SIZE_BYTES 110
1221
+ #define THREADS_PER_BLOCK 16
1222
// MUL_ACC_IQ3_S variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 110-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread handles 16 consecutive src1
// values: two 8-value groups, each built from two 4-value iq3s_grid entries,
// with signs stored directly as one byte per group.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1223
+ var acc: array<f32, OUTPUTS_PER_WG>;
1224
+
1225
+ let tid = thread_id % THREADS_PER_BLOCK;
1226
+ let block_group = thread_id / THREADS_PER_BLOCK;
1227
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
1228
+
1229
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of
// it; slot0: first of the two 8-element groups (0 or 2) this thread handles.
+ let sub_blk = tid / 2u;
1230
+ let half = tid % 2u;
1231
+ let slot0 = half * 2u;
1232
+ let y_offset = sub_blk * 32u + slot0 * 8u;
1233
+
1234
+ let num_blocks = params.k / BLOCK_SIZE;
1235
+
1236
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
1237
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
1238
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
1239
+ for (var i = 0u; i < 16u; i++) {
1240
+ x_block[i] = f32(src1[x_base + i]);
1241
+ }
1242
+
1243
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1244
+ let output_row = row_base + row;
1245
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1246
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1247
+ let d = f32(load_f16_at_src0(block_byte_base));
1248
// Per sub-block: eight index bytes (from byte 2), one qh byte (from byte 66)
// and four sign bytes (from byte 74); four scale bytes sit at byte 106.
+ let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
1249
+ let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
1250
+ let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
1251
+ let qh_byte = get_byte(qh_word, sub_blk % 4u);
1252
+ let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u);
1253
+ let sc_word = load_u32_at_src0(block_byte_base + 106u);
1254
// One scale nibble per sub-block (two sub-blocks share a byte), applied as 1 + 2*s.
+ let scales_byte = get_byte(sc_word, sub_blk / 2u);
1255
+ let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu;
1256
+ let db = d * (1.0 + 2.0 * f32(sub_scale));
1257
+
1258
+ var row_sum = 0.0;
1259
+ for (var ll = 0u; ll < 2u; ll++) {
1260
+ let l = slot0 + ll;
1261
// Group l uses index bytes 2*l and 2*l+1 of the sub-block's eight.
+ let qs_word = select(qs_hi, qs_lo, l < 2u);
1262
+ let byte_pos = (l % 2u) * 2u;
1263
+ let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
1264
+ let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
1265
// 9-bit grid indices: bit 8 comes from bits 2*l and 2*l+1 of qh_byte.
+ let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u);
1266
+ let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u);
1267
+ let sign_byte = get_byte(sg_w, l);
1268
// Each grid entry is one u32 holding four magnitude bytes.
+ let grid1 = iq3s_grid[grid_idx_1];
1269
+ let grid2 = iq3s_grid[grid_idx_2];
1270
+ for (var j = 0u; j < 4u; j++) {
1271
+ let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
1272
+ let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
1273
// Sign bits 0..3 apply to grid1's values, bits 4..7 to grid2's.
+ let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
1274
+ let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u);
1275
+ row_sum += db * b1 * s1 * x_block[ll * 8u + j];
1276
+ row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
1277
+ }
1278
+ }
1279
+ acc[row] += row_sum;
1280
+ }
1281
+ }
1282
+ }
1283
+
1284
+ return acc;
1285
+ }
1286
+ #endif
1287
+
1288
+ #ifdef MUL_ACC_IQ4_NL
1289
+ #define BLOCK_SIZE 32
1290
+ #define BLOCK_SIZE_BYTES 18
1291
+ #define THREADS_PER_BLOCK 4
1292
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
1293
// MUL_ACC_IQ4_NL variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 4 threads walk the
// 32-element / 18-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Each thread decodes one 32-bit word of
// packed nibbles (8 values) through the kvalues_iq4nl lookup table.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1294
+ var acc: array<f32, OUTPUTS_PER_WG>;
1295
+
1296
+ let num_blocks = params.k / BLOCK_SIZE;
1297
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
1298
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
1299
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
1300
// Cache 4 src1 values from the first 16 elements of the block and the 4
// values 16 elements further on; reused for every output row.
+ var x_block: array<f32, ELEMS_PER_THREAD>;
1301
+ for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
1302
+ x_block[i] = f32(src1[x_base + i]);
1303
+ x_block[i + 4u] = f32(src1[x_base + i + 16u]);
1304
+ }
1305
+
1306
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1307
+ let output_row = row_base + row;
1308
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1309
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1310
+ let d = f32(load_f16_at_src0(block_byte_base));
1311
+ var row_sum = 0.0;
1312
+
1313
// This thread's four quant bytes, located after the 2-byte f16 scale.
+ let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
1314
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
1315
+ let q_byte = get_byte(q_packed, byte_idx);
1316
// Low nibble pairs with the first cached run, high nibble with the run 16
// elements later.
+ let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d;
1317
+ let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d;
1318
+ row_sum += q_lo * x_block[byte_idx];
1319
+ row_sum += q_hi * x_block[byte_idx + 4u];
1320
+ }
1321
+ acc[row] += row_sum;
1322
+ }
1323
+ }
1324
+ }
1325
+
1326
+ return acc;
1327
+ }
1328
+ #endif
1329
+
1330
+ #ifdef MUL_ACC_IQ4_XS
1331
+ #define BLOCK_SIZE 256
1332
+ #define BLOCK_SIZE_BYTES 136
1333
+ #define THREADS_PER_BLOCK 16
1334
// MUL_ACC_IQ4_XS variant of accumulate_vec_dot.
// Returns this thread's partial dot products, one entry per output row
// (OUTPUTS_PER_WG rows starting at row_base). Groups of 16 threads walk the
// 256-element / 136-byte src0 blocks of a row, stepping by
// WG_SIZE / THREADS_PER_BLOCK. Two threads share each 32-element sub-block:
// both read the same 16 quant bytes, one decoding the low nibbles and the
// other the high nibbles, through the kvalues_iq4nl lookup table.
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
1335
+ var acc: array<f32, OUTPUTS_PER_WG>;
1336
+
1337
+ let tid = thread_id % THREADS_PER_BLOCK;
1338
+ let block_group = thread_id / THREADS_PER_BLOCK;
1339
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
1340
+
1341
// sub_blk: which 32-element sub-block (0..7); half: which 16-element half of it.
+ let sub_blk = tid / 2u;
1342
+ let half = tid % 2u;
1343
+ let y_offset = sub_blk * 32u + half * 16u;
1344
+
1345
+ let num_blocks = params.k / BLOCK_SIZE;
1346
+
1347
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
1348
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
1349
// Cache the 16 src1 values once per block; reused for every output row.
+ var x_block: array<f32, 16>;
1350
+ for (var i = 0u; i < 16u; i++) {
1351
+ x_block[i] = f32(src1[x_base + i]);
1352
+ }
1353
+
1354
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1355
+ let output_row = row_base + row;
1356
// Rows at or beyond params.m are skipped; their acc entry is never written here.
+ if (output_row < params.m) {
1357
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
1358
+ let d = f32(load_f16_at_src0(block_byte_base));
1359
// 6-bit sub-block scale: low 4 bits from a nibble of the word at byte 4,
// high 2 bits from the 16-bit word at byte 2; 32 is subtracted before use.
+ let scales_h = load_u16_at_src0(block_byte_base + 2u);
1360
+ let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
1361
+ let sl_byte = get_byte(scales_l_word, sub_blk / 2u);
1362
+ let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu;
1363
+ let sh_bits = (scales_h >> (2u * sub_blk)) & 3u;
1364
+ let ls = i32(sl | (sh_bits << 4u));
1365
+ let dl = d * f32(ls - 32);
1366
+
1367
// The sub-block's 16 quant bytes, starting at byte 8 of the block.
+ let qs_byte_off = 8u + sub_blk * 16u;
1368
+ let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off);
1369
+ let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u);
1370
+ let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u);
1371
+ let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u);
1372
+
1373
+ var row_sum = 0.0;
1374
+ for (var i = 0u; i < 16u; i++) {
1375
// Pick word i / 4 of the four loaded words without dynamic array indexing.
+ let q_word = select(
1376
+ select(q_w0, q_w1, i >= 4u),
1377
+ select(q_w2, q_w3, i >= 12u),
1378
+ i >= 8u);
1379
+ let q_byte = get_byte(q_word, i % 4u);
1380
// half == 0 decodes the low nibble, half == 1 the high nibble.
+ let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u);
1381
+ row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i];
1382
+ }
1383
+ acc[row] += row_sum;
1384
+ }
1385
+ }
1386
+ }
1387
+
1388
+ return acc;
1389
+ }
1390
+ #endif
1391
+
1392
+ #ifdef MUL_ACC_MXFP4
1393
+ #define BLOCK_SIZE 32
1394
+ #define BLOCK_SIZE_BYTES 17
1395
+ #define THREADS_PER_BLOCK 4
1396
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
1397
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { // MXFP4 variant: returns this thread's partial dot-product sums, one slot per output row row_base + 0 .. row_base + OUTPUTS_PER_WG - 1
1398
+ var acc: array<f32, OUTPUTS_PER_WG>; // no initializer: WGSL zero-initializes function-scope vars, so each row's sum starts at 0.0
1399
+ // Thread mapping: THREADS_PER_BLOCK (4) threads cooperate on one BLOCK_SIZE (32) element block, ELEMS_PER_THREAD (8) elements each.
1400
+ let num_blocks = params.k / BLOCK_SIZE; // whole blocks only; assumes params.k is a multiple of BLOCK_SIZE -- TODO confirm against the caller
1401
+ let thread_within_block = thread_id % 4; // literal 4 equals THREADS_PER_BLOCK
1402
+ for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { // blocks are interleaved across groups of 4 threads
1403
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; // each thread's slice starts 4 elements after the previous thread's
1404
+ var x_block: array<f32, ELEMS_PER_THREAD>; // slots 0..3 pair with low nibbles, slots 4..7 with high nibbles
1405
+ for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
1406
+ x_block[i] = f32(src1[x_base + i]); // 4 consecutive elements from the first 16 of the block
1407
+ x_block[i + 4] = f32(src1[x_base + i + 16]); // the elements 16 positions later, i.e. in the second half of the block
1408
+ }
1409
+ // The cached src1 values are dotted against the matching elements of each output row in turn.
1410
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
1411
+ let output_row = row_base + row;
1412
+ if (output_row < params.m) { // rows at or beyond params.m are skipped; their acc slots stay 0.0
1413
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; // offsets and stride_01 are counted in blocks here, then scaled to bytes
1414
+ let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); // byte 0 of the block: the exponent byte shared by all 32 elements
1415
+ let e = ldexp(1.0, i32(eu8) - 128); // block scale = 2^(eu8 - 128)
1416
+ var row_sum = 0.0;
1417
+ let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); // this thread's 4 packed bytes; the +1 skips the exponent byte, so the address is generally not 4-byte aligned -- presumably load_u32_at_src0 handles that, confirm
1418
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
1419
+ let q_byte = get_byte(q_packed, byte_idx); // one byte holds two 4-bit codes
1420
+ let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; // low nibble; kvalues_mxfp4 (defined outside this block) maps the code to its value
1421
+ let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; // high nibble, same table and scale
1422
+ row_sum += q_lo * x_block[byte_idx]; // low nibble pairs with the first-half src1 element
1423
+ row_sum += q_hi * x_block[byte_idx + 4u]; // high nibble pairs with the element 16 positions later
1424
+ }
1425
+ acc[row] += row_sum;
1426
+ }
1427
+ }
1428
+ }
1429
+
1430
+ return acc; // per-thread partial sums; combining them across threads is presumably done by the caller -- not visible here
1431
+ }
1432
+ #endif