whispercpp 1.3.6 → 1.3.7

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -0,0 +1,1432 @@
1
+ #ifdef U32_DEQUANT_HELPERS
2
+ #define SRC0_TYPE u32
3
+
4
// Returns byte `b` of `v` as an unsigned value in 0..255.
// b = 0 selects the least-significant byte.
fn byte_of(v: u32, b: u32) -> u32 {
    let shift = 8u * b;
    let shifted = v >> shift;
    return shifted & 0xFFu;
}
7
+
8
// Returns byte `b` of `v` interpreted as a two's-complement signed 8-bit
// value, i.e. in the range -128..127. b = 0 selects the least-significant byte.
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_byte = i32((v >> (8u * b)) & 0xFFu);
    if (unsigned_byte >= 128) {
        // High bit set: map 128..255 onto -128..-1.
        return unsigned_byte - 256;
    }
    return unsigned_byte;
}
12
+ #endif
13
+
14
+ #ifdef VEC
15
+ #define VEC_SIZE 4u
16
+ #define SRC0_TYPE vec4<SRC0_INNER_TYPE>
17
+ #define SRC1_TYPE vec4<SRC1_INNER_TYPE>
18
+
19
// Vector variant: converts the src0 vector to the src1 vector type, takes the
// dot product of the two vectors, and returns the result as f32.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    let src0_converted = SRC1_TYPE(src0_val);
    let product = dot(src0_converted, src1_val);
    return f32(product);
}
22
+ #endif
23
+
24
+ #ifdef SCALAR
25
+ #define VEC_SIZE 1u
26
+ #define SRC0_TYPE SRC0_INNER_TYPE
27
+ #define SRC1_TYPE SRC1_INNER_TYPE
28
+
29
// Scalar variant: converts both operands to f32 and returns their product.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    let lhs = f32(src0_val);
    let rhs = f32(src1_val);
    return lhs * rhs;
}
32
+ #endif
33
+
34
+ #ifdef MUL_ACC_FLOAT
35
// Float variant: accumulates this thread's share of the dot products between
// the src1 vector and OUTPUTS_PER_WG consecutive src0 rows starting at
// `row_base`.
//
//   thread_id         - starting index into K for this thread; the thread then
//                       advances by WG_SIZE each iteration
//   row_base          - first output row of the tile
//   src0_batch_offset - element offset added to every src0 row address
//   src1_idx_base     - element offset of the src1 vector
//
// Returns one partial sum per row of the tile. Rows at or beyond params.m are
// never touched and keep their zero-initialized value.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    // K and the src1 base, expressed in units of VEC_SIZE elements.
    let vec_count = params.k / VEC_SIZE;
    let src1_vec_base = src1_idx_base / VEC_SIZE;

    // Load each src1 value once and reuse it for every row of the tile.
    for (var kv = thread_id; kv < vec_count; kv += WG_SIZE) {
        let rhs = src1[src1_vec_base + kv];
        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let row_start = src0_batch_offset + out_row * params.stride_01;
            partial[r] += inner_dot(src0[row_start / VEC_SIZE + kv], rhs);
        }
    }

    return partial;
}
56
+ #endif
57
+
58
+ #ifdef MUL_ACC_Q1_0
59
+ #define BLOCK_SIZE 128
60
+ #define BLOCK_SIZE_BYTES 18
61
+ #define THREADS_PER_BLOCK 16
62
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
63
// Q1_0 variant: accumulates this thread's share of the dot products between
// the src1 vector and OUTPUTS_PER_WG consecutive 1-bit-quantized src0 rows
// starting at `row_base`.
//
// As read here, a src0 block starts with an f16 scale `d` at byte 0, followed
// from byte 2 by one sign bit per element: a set bit contributes +d, a clear
// bit contributes -d. THREADS_PER_BLOCK threads share a block and each one
// handles a single byte (8 elements).
//
// Returns one partial sum per row of the tile. Rows at or beyond params.m are
// never touched and keep their zero-initialized value.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let block_count = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += WG_SIZE / THREADS_PER_BLOCK) {
        // Cache this lane's src1 values so they are read once per block
        // rather than once per output row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane * ELEMS_PER_THREAD;
        var rhs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            rhs[e] = f32(src1[src1_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let block_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(block_bytes));
            // Only the low byte of the loaded word belongs to this lane.
            let sign_bits = load_u32_at_src0(block_bytes + 2u + lane) & 0xFFu;
            var dot_sum = 0.0;
            for (var bit = 0u; bit < 8u; bit++) {
                let bit_set = ((sign_bits >> bit) & 1u) != 0u;
                var weight = -scale;
                if (bit_set) {
                    weight = scale;
                }
                dot_sum += weight * rhs[bit];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
93
+ #endif
94
+
95
+ #ifdef MUL_ACC_Q4_0
96
+ #define BLOCK_SIZE 32
97
+ #define BLOCK_SIZE_BYTES 18
98
+ #define THREADS_PER_BLOCK 4
99
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
100
// Q4_0 variant: accumulates this thread's share of the dot products between
// the src1 vector and OUTPUTS_PER_WG consecutive 4-bit-quantized src0 rows
// starting at `row_base`.
//
// As read here, a src0 block starts with an f16 scale `d` at byte 0, followed
// from byte 2 by packed 4-bit quants. Each quant is dequantized as (q - 8) * d.
// THREADS_PER_BLOCK threads share a block and each one handles a single u32
// (4 bytes, 8 quants).
//
//   thread_id         - index of the calling thread; selects both the lane
//                       within a block and the first block visited
//   row_base          - first output row of the tile
//   src0_batch_offset - block offset added to every src0 row address
//   src1_idx_base     - element offset of the src1 vector
//
// Returns one partial sum per row of the tile. Rows at or beyond params.m are
// never touched and keep their zero-initialized value.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    // Derive the lane from THREADS_PER_BLOCK (not a literal 4) so it cannot
    // drift from the block start and stride below, which use the same macro.
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // Byte j of this thread's u32 feeds two elements of the block: its low
        // nibble pairs with src1 element 4 * lane + j, its high nibble with
        // the element 16 positions later. x_block[0..3] holds the former and
        // x_block[4..7] the latter.
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 4] = f32(src1[x_base + i + 16]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row >= params.m) {
                continue;
            }
            let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(block_byte_base));
            var row_sum = 0.0;

            // Quants start 2 bytes in, after the f16 scale; 4 bytes per lane.
            let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
            for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                // NOTE(review): get_byte is defined elsewhere; presumably it
                // returns byte `byte_idx` of the packed word - confirm.
                let q_byte = get_byte(q_packed, byte_idx);
                let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
                let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d;
                row_sum += q_lo * x_block[byte_idx];
                row_sum += q_hi * x_block[byte_idx + 4u];
            }
            acc[row] += row_sum;
        }
    }

    return acc;
}
135
+ #endif
136
+
137
+ #ifdef MUL_ACC_Q4_1
138
+ #define BLOCK_SIZE 32
139
+ #define BLOCK_SIZE_BYTES 20
140
+ #define THREADS_PER_BLOCK 4
141
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q4_1 variant: per-thread partial dot products between src0 rows (20-byte blocks of
// 32 weights) and src1.
//   thread_id         - picks the block (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
142
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
143
+ var acc: array<f32, OUTPUTS_PER_WG>;
144
+
145
+ let num_blocks = params.k / BLOCK_SIZE;
146
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
147
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
// Each thread reads 4 src1 values at x_base and 4 more at x_base + 16, once per block.
148
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
149
+ var x_block: array<f32, ELEMS_PER_THREAD>;
150
+ for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
151
+ x_block[i] = f32(src1[x_base + i]);
152
+ x_block[i + 4] = f32(src1[x_base + i + 16]);
153
+ }
154
+
155
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
156
+ let output_row = row_base + row;
157
+ if (output_row < params.m) {
158
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
// f16 at byte 0 is the multiplier d, f16 at byte 2 the additive term m; packed
// 4-bit values start at byte 4.
159
+ let d = f32(load_f16_at_src0(block_byte_base));
160
+ let m = f32(load_f16_at_src0(block_byte_base + 2u));
161
+ var row_sum = 0.0;
162
+
163
+ let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block);
164
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
// Low nibble pairs with x_block[byte_idx], high nibble with x_block[byte_idx + 4];
// each weight is nibble * d + m.
165
+ let q_byte = get_byte(q_packed, byte_idx);
166
+ let q_lo = f32(q_byte & 0xFu) * d + m;
167
+ let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m;
168
+ row_sum += q_lo * x_block[byte_idx];
169
+ row_sum += q_hi * x_block[byte_idx + 4u];
170
+ }
171
+ acc[row] += row_sum;
172
+ }
173
+ }
174
+ }
175
+
176
+ return acc;
177
+ }
178
+ #endif
179
+
180
+ #ifdef MUL_ACC_Q5_0
181
+ #define BLOCK_SIZE 32
182
+ #define BLOCK_SIZE_BYTES 22
183
+ #define THREADS_PER_BLOCK 4
184
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q5_0 variant: per-thread partial dot products between src0 rows (22-byte blocks of
// 32 weights) and src1.
//   thread_id         - picks the block (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
185
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
186
+ var acc: array<f32, OUTPUTS_PER_WG>;
187
+
188
+ let num_blocks = params.k / BLOCK_SIZE;
189
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
190
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
// Each thread reads 4 src1 values at x_base and 4 more at x_base + 16, once per block.
191
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
192
+ var x_block: array<f32, ELEMS_PER_THREAD>;
193
+ for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
194
+ x_block[i] = f32(src1[x_base + i]);
195
+ x_block[i + 4] = f32(src1[x_base + i + 16]);
196
+ }
197
+
198
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
199
+ let output_row = row_base + row;
200
+ if (output_row < params.m) {
201
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
// f16 scale at byte 0, a 32-bit word of extra high bits at byte 2, packed nibbles
// from byte 6. qh_shift is the index of this thread's first bit in qh_packed.
202
+ let d = f32(load_f16_at_src0(block_byte_base));
203
+ let qh_packed = load_u32_at_src0(block_byte_base + 2u);
204
+ let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block);
205
+ let qh_shift = thread_within_block * 4u;
206
+ var row_sum = 0.0;
207
+
208
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
// qh_lo moves bit (qh_shift + byte_idx) of qh_packed into bit 4; qh_hi moves bit
// (qh_shift + byte_idx + 16) into bit 4 (shift by +12, then mask 0x10).
// Each 5-bit value is offset by -16 and scaled by d.
209
+ let q_byte = get_byte(q_packed, byte_idx);
210
+ let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
211
+ let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
212
+ let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d;
213
+ let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d;
214
+ row_sum += q_lo * x_block[byte_idx];
215
+ row_sum += q_hi * x_block[byte_idx + 4u];
216
+ }
217
+ acc[row] += row_sum;
218
+ }
219
+ }
220
+ }
221
+
222
+ return acc;
223
+ }
224
+ #endif
225
+
226
+ #ifdef MUL_ACC_Q5_1
227
+ #define BLOCK_SIZE 32
228
+ #define BLOCK_SIZE_BYTES 24
229
+ #define THREADS_PER_BLOCK 4
230
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q5_1 variant: per-thread partial dot products between src0 rows (24-byte blocks of
// 32 weights) and src1.
//   thread_id         - picks the block (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
231
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
232
+ var acc: array<f32, OUTPUTS_PER_WG>;
233
+
234
+ let num_blocks = params.k / BLOCK_SIZE;
235
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
236
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
// Each thread reads 4 src1 values at x_base and 4 more at x_base + 16, once per block.
237
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
238
+ var x_block: array<f32, ELEMS_PER_THREAD>;
239
+ for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
240
+ x_block[i] = f32(src1[x_base + i]);
241
+ x_block[i + 4] = f32(src1[x_base + i + 16]);
242
+ }
243
+
244
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
245
+ let output_row = row_base + row;
246
+ if (output_row < params.m) {
247
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
// f16 multiplier d at byte 0, f16 additive term m at byte 2, a 32-bit word of extra
// high bits at byte 4, packed nibbles from byte 8.
248
+ let d = f32(load_f16_at_src0(block_byte_base));
249
+ let m = f32(load_f16_at_src0(block_byte_base + 2u));
250
+ let qh_packed = load_u32_at_src0(block_byte_base + 4u);
251
+ let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block);
252
+ let qh_shift = thread_within_block * 4u;
253
+ var row_sum = 0.0;
254
+
255
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
// qh_lo / qh_hi move bits (qh_shift + byte_idx) and (qh_shift + byte_idx + 16) of
// qh_packed into bit 4; each weight is the 5-bit value * d + m.
256
+ let q_byte = get_byte(q_packed, byte_idx);
257
+ let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
258
+ let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
259
+ let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m;
260
+ let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m;
261
+ row_sum += q_lo * x_block[byte_idx];
262
+ row_sum += q_hi * x_block[byte_idx + 4u];
263
+ }
264
+ acc[row] += row_sum;
265
+ }
266
+ }
267
+ }
268
+
269
+ return acc;
270
+ }
271
+ #endif
272
+
273
+ #ifdef MUL_ACC_Q8_0
274
+ #define BLOCK_SIZE 32
275
+ #define BLOCK_SIZE_BYTES 34
276
+ #define THREADS_PER_BLOCK 4
277
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q8_0 variant: per-thread partial dot products between src0 rows (34-byte blocks of
// 32 weights) and src1.
//   thread_id         - picks the block (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
278
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
279
+ var acc: array<f32, OUTPUTS_PER_WG>;
280
+
281
+ let num_blocks = params.k / BLOCK_SIZE;
282
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
283
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
// Each thread reads ELEMS_PER_THREAD consecutive src1 values, once per block.
284
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
285
+ var x_block: array<f32, ELEMS_PER_THREAD>;
286
+ for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
287
+ x_block[i] = f32(src1[x_base + i]);
288
+ }
289
+
290
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
291
+ let output_row = row_base + row;
292
+ if (output_row < params.m) {
293
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
// f16 scale at byte 0; 8-bit values start at byte 2.
294
+ let d = f32(load_f16_at_src0(block_byte_base));
295
+ var row_sum = 0.0;
296
+
// Two 32-bit words per thread (word index thread_within_block * 2 + packed_idx).
// get_byte_i32 is defined elsewhere; presumably it returns the byte sign-extended - confirm.
297
+ for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
298
+ let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx));
299
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
300
+ let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d;
301
+ row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
302
+ }
303
+ }
304
+ acc[row] += row_sum;
305
+ }
306
+ }
307
+ }
308
+
309
+ return acc;
310
+ }
311
+ #endif
312
+
313
+ #ifdef MUL_ACC_Q8_1
314
+ #define BLOCK_SIZE 32
315
+ #define BLOCK_SIZE_BYTES 36
316
+ #define THREADS_PER_BLOCK 4
317
+ #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q8_1 variant: per-thread partial dot products between src0 rows (36-byte blocks of
// 32 weights) and src1.
//   thread_id         - picks the block (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
318
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
319
+ var acc: array<f32, OUTPUTS_PER_WG>;
320
+
321
+ let num_blocks = params.k / BLOCK_SIZE;
322
+ let thread_within_block = thread_id % THREADS_PER_BLOCK;
323
+ for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
// Each thread reads ELEMS_PER_THREAD consecutive src1 values, once per block.
324
+ let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
325
+ var x_block: array<f32, ELEMS_PER_THREAD>;
326
+ for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
327
+ x_block[i] = f32(src1[x_base + i]);
328
+ }
329
+
330
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
331
+ let output_row = row_base + row;
332
+ if (output_row < params.m) {
333
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
// f16 at byte 0 is the multiplier d, f16 at byte 2 is read as m; 8-bit values start at byte 4.
// NOTE(review): m is added to every element as an offset below - confirm this matches
// the meaning of the second f16 in the Q8_1 block layout used by the producer.
334
+ let d = f32(load_f16_at_src0(block_byte_base));
335
+ let m = f32(load_f16_at_src0(block_byte_base + 2u));
336
+ var row_sum = 0.0;
337
+
// Two 32-bit words per thread (word index thread_within_block * 2 + packed_idx).
// get_byte_i32 is defined elsewhere; presumably it returns the byte sign-extended - confirm.
338
+ for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
339
+ let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx));
340
+ for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
341
+ let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m;
342
+ row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
343
+ }
344
+ }
345
+ acc[row] += row_sum;
346
+ }
347
+ }
348
+ }
349
+
350
+ return acc;
351
+ }
352
+ #endif
353
+
354
+ #ifdef MUL_ACC_Q2_K
355
+ #define BLOCK_SIZE 256
356
+ #define BLOCK_SIZE_BYTES 84
357
+ #define THREADS_PER_BLOCK 16
// Q2_K variant: per-thread partial dot products between src0 rows (84-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
358
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
359
+ var acc: array<f32, OUTPUTS_PER_WG>;
360
+
361
+ let tid = thread_id % THREADS_PER_BLOCK;
362
+ let block_group = thread_id / THREADS_PER_BLOCK;
363
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
364
+
// Split tid (0..15) into the indices used for the byte and element offsets below.
365
+ let lane = tid / 2u;
366
+ let phase = tid % 2u;
367
+ let iq = lane / 4u;
368
+ let ir = lane % 4u;
369
+ let is = ir / 2u;
370
+
// Offsets inside the block: y_offset indexes src1; sc*_byte are the four scale bytes
// (from byte 0); qs_byte is the 4-byte group of packed 2-bit values (from byte 16).
371
+ let y_offset = 128u * iq + 8u * ir + 4u * phase;
372
+ let sc0_byte = 8u * iq + is;
373
+ let sc2_byte = 8u * iq + is + 2u;
374
+ let sc4_byte = 8u * iq + is + 4u;
375
+ let sc6_byte = 8u * iq + is + 6u;
376
+ let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase;
377
+
378
+ let num_blocks = params.k / BLOCK_SIZE;
379
+
380
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// Four runs of 4 src1 values, 32 elements apart, loaded once per block.
381
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
382
+ var x_block: array<f32, 16>;
383
+ for (var i = 0u; i < 4u; i++) {
384
+ x_block[i] = f32(src1[x_base + i]);
385
+ x_block[i + 4u] = f32(src1[x_base + 32u + i]);
386
+ x_block[i + 8u] = f32(src1[x_base + 64u + i]);
387
+ x_block[i + 12u] = f32(src1[x_base + 96u + i]);
388
+ }
389
+
390
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
391
+ let output_row = row_base + row;
392
+ if (output_row < params.m) {
393
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
394
+
// f16 values at bytes 80 and 82. dmin carries an extra 1/16 because the
// (sc & 0xF0) terms below are left unshifted (16x too large).
395
+ let dall = f32(load_f16_at_src0(block_byte_base + 80u));
396
+ let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0);
397
+
// Pick single scale bytes out of the 32-bit word that contains them.
// load_u32_at_src0_aligned / byte_of are defined elsewhere - semantics assumed from use.
398
+ let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u);
399
+ let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u);
400
+ let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u);
401
+ let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u);
402
+
// Split the 4 packed bytes into two 16-bit halves (bytes 0-1 and bytes 2-3).
403
+ let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte);
404
+ let qs0 = q_u32 & 0xFFFFu;
405
+ let qs1 = q_u32 >> 16u;
406
+
407
+ var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
408
+ var acc1 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
409
+ var acc2 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
410
+
// Plain sums of each run of 4 src1 values, used for the dmin term.
411
+ sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3];
412
+ sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7];
413
+ sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11];
414
+ sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15];
415
+
// 2-bit fields are masked in place, not shifted down: acc1 uses the low byte of
// each half, acc2 the high byte (256x too large). Fields at bits 2/4/6 are
// 4x/16x/64x too large. Both factors are divided out in the final expression.
416
+ acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u);
417
+ acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u);
418
+ acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu);
419
+ acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u);
420
+ acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u);
421
+ acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u);
422
+ acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u);
423
+ acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u);
424
+
// Low nibble of each scale byte weights the quantized sums (times dall);
// high nibble weights the src1 sums (times dmin), which is subtracted.
425
+ acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) +
426
+ (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 +
427
+ (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 +
428
+ (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0)
429
+ - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) +
430
+ sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u));
431
+ }
432
+ }
433
+ }
434
+
435
+ return acc;
436
+ }
437
+ #endif
438
+
439
+ #ifdef MUL_ACC_Q3_K
440
+ #define BLOCK_SIZE 256
441
+ #define BLOCK_SIZE_BYTES 110
442
+ #define THREADS_PER_BLOCK 16
// Q3_K variant: per-thread partial dot products between src0 rows (110-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
443
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
444
+ var acc: array<f32, OUTPUTS_PER_WG>;
445
+
446
+ let tid = thread_id % THREADS_PER_BLOCK;
447
+ let block_group = thread_id / THREADS_PER_BLOCK;
448
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
449
+
// Split tid (0..15) into the indices used below; il is either 0 or 2.
450
+ let lane = tid / 2u;
451
+ let phase = tid % 2u;
452
+ let ip = lane / 4u;
453
+ let il = 2u * ((lane % 4u) / 2u);
454
+ let ir = lane % 2u;
455
+ let l0 = 8u * ir;
456
+
// Byte offsets inside the block: q_byte for the packed 2-bit values (from byte 32),
// h_byte for the high-bit mask (from byte 0); y_offset indexes src1.
457
+ let q_byte = 32u + 32u * ip + l0 + 16u * phase;
458
+ let h_byte = l0 + 16u * phase;
459
+ let y_offset = 128u * ip + 32u * il + l0 + 16u * phase;
460
+
461
+ let s_shift1 = 4u * ip;
462
+ let s_shift2 = s_shift1 + il;
463
+
// select(f, t, cond) yields t when cond is true: v1 is 4.0 for il == 0, else 64.0.
464
+ let v1 = select(64.0, 4.0, il == 0u);
465
+ let v2 = 4.0 * v1;
466
+ let shift = 2u * il;
467
+
// In-place masks for the 2-bit fields (low byte / high byte of a 16-bit half).
468
+ var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32;
469
+ if (il == 0u) {
470
+ qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u;
471
+ } else {
472
+ qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u;
473
+ }
474
+
// Single-bit masks into the high-bit halves, chosen by mm_idx (0..3).
475
+ let mm_idx = 2u * ip + il / 2u;
476
+ var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32;
477
+ switch (mm_idx) {
478
+ case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; }
479
+ case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; }
480
+ case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; }
481
+ default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; }
482
+ }
483
+
484
+ let num_blocks = params.k / BLOCK_SIZE;
485
+
486
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// Two runs of 8 src1 values, 32 elements apart, loaded once per block.
487
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
488
+ var x_block: array<f32, 16>;
489
+ for (var i = 0u; i < 8u; i++) {
490
+ x_block[i] = f32(src1[x_base + i]);
491
+ x_block[i + 8u] = f32(src1[x_base + 32u + i]);
492
+ }
493
+
494
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
495
+ let output_row = row_base + row;
496
+ if (output_row < params.m) {
497
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
498
+
// f16 scale at byte 108; packed scale data is read as 16-bit words from byte 96.
499
+ let d = f32(load_f16_at_src0(block_byte_base + 108u));
500
+ let a_base = 96u;
501
+ let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u);
502
+ let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u);
503
+ let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u);
504
+ let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u);
505
+
// Build four 6-bit scales, one per byte of scales32: 2 upper bits (into bits 4-5)
// from a_4/a_5, 4 lower bits from a_il0/a_il1.
506
+ var scales32 = a_4 | (a_5 << 16u);
507
+ let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u;
508
+ scales32 = a_il0 | (a_il1 << 16u);
509
+ scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32;
510
+
// Each scale byte is offset by -32 before use.
511
+ let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32);
512
+ let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32);
513
+
514
+ let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u);
515
+ let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u);
516
+ let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u);
517
+ let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u);
518
+
519
+ var s1 = 0.0; var s2 = 0.0; var s3 = 0.0;
520
+ var s4 = 0.0; var s5 = 0.0; var s6 = 0.0;
521
+
// Walk the 8 bytes two at a time: l selects the word (l >= 4) and the 16-bit half
// (bit 1 of l). s3/s6 sum the src1 values whose high-bit mask bit is clear.
522
+ for (var l = 0u; l < 8u; l += 2u) {
523
+ let q_u32 = select(q_u32_0, q_u32_1, l >= 4u);
524
+ let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u);
525
+ let h_u32 = select(h_u32_0, h_u32_1, l >= 4u);
526
+ let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u);
527
+
528
+ s1 += x_block[l + 0u] * f32(qs & qm0);
529
+ s2 += x_block[l + 1u] * f32(qs & qm1);
530
+ s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) +
531
+ select(0.0, x_block[l + 1u], (hv & hm1) == 0u);
532
+ s4 += x_block[l + 8u] * f32(qs & qm2);
533
+ s5 += x_block[l + 9u] * f32(qs & qm3);
534
+ s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) +
535
+ select(0.0, x_block[l + 9u], (hv & hm3) == 0u);
536
+ }
537
+
// 1/256 rescales the high-byte masks, 0.25 the second field pair, and the division
// by (1 << shift) the in-place position of the qm* masks.
538
+ let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1);
539
+ let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2);
540
+ acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift);
541
+ }
542
+ }
543
+ }
544
+
545
+ return acc;
546
+ }
547
+ #endif
548
+
549
+ #ifdef MUL_ACC_Q4_K
550
+ #define BLOCK_SIZE 256
551
+ #define BLOCK_SIZE_BYTES 144
552
+ #define THREADS_PER_BLOCK 16
// Q4_K variant: per-thread partial dot products between src0 rows (144-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
553
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
554
+ var acc: array<f32, OUTPUTS_PER_WG>;
555
+
556
+ let tid = thread_id % THREADS_PER_BLOCK;
557
+ let block_group = thread_id / THREADS_PER_BLOCK;
558
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
559
+
// Split tid (0..15) into the indices used for the offsets below.
560
+ let il = tid / 4u;
561
+ let ir = tid % 4u;
562
+ let im = il / 2u;
563
+ let in = il % 2u;
564
+ let l0 = 4u * (2u * ir + in);
565
+
// y_offset indexes src1; q_offset is relative to the packed nibbles; sc*_byte are
// the byte offsets of three 16-bit scale words (scale data starts at byte 4).
566
+ let y_offset = 64u * im + l0;
567
+ let q_offset = 32u * im + l0;
568
+ let sc0_byte = 4u + im * 2u;
569
+ let sc2_byte = 4u + (im + 2u) * 2u;
570
+ let sc4_byte = 4u + (im + 4u) * 2u;
571
+
572
+ let num_blocks = params.k / BLOCK_SIZE;
573
+
574
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// Four runs of 4 src1 values at +0, +32, +128 and +160, loaded once per block.
575
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
576
+ var x_block: array<f32, 16>;
577
+ for (var i = 0u; i < 4u; i++) {
578
+ x_block[i] = f32(src1[x_base + i]);
579
+ x_block[i + 4u] = f32(src1[x_base + 32u + i]);
580
+ x_block[i + 8u] = f32(src1[x_base + 128u + i]);
581
+ x_block[i + 12u] = f32(src1[x_base + 160u + i]);
582
+ }
583
+
584
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
585
+ let output_row = row_base + row;
586
+ if (output_row < params.m) {
587
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
588
+
// f16 values at bytes 0 and 2: d scales the quantized sums, dmin the src1 sums.
589
+ let d = f32(load_f16_at_src0(block_byte_base + 0u));
590
+ let dmin = f32(load_f16_at_src0(block_byte_base + 2u));
591
+
// Take the 16-bit half of each loaded word: the upper half when bit 1 of the
// byte offset is set, else the lower half.
// NOTE(review): the half is chosen from sc*_byte only, so block_byte_base is
// presumably expected to be 4-byte aligned here - confirm.
592
+ let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
593
+ let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
594
+ let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
595
+ let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
596
+ let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
597
+ let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
598
+
// Two 6-bit values per 16-bit word: sc16_0/sc16_1 keep the low 6 bits of each byte;
// sc16_2/sc16_3 combine a nibble of sc4 with the top 2 bits of sc0/sc2.
599
+ let sc16_0 = sc0 & 0x3F3Fu;
600
+ let sc16_1 = sc2 & 0x3F3Fu;
601
+ let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
602
+ let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
603
+
604
+ let scale0 = f32(sc16_0 & 0xFFu);
605
+ let scale1 = f32((sc16_0 >> 8u) & 0xFFu);
606
+ let min0 = f32(sc16_1 & 0xFFu);
607
+ let min1 = f32((sc16_1 >> 8u) & 0xFFu);
608
+ let scale2 = f32(sc16_2 & 0xFFu);
609
+ let scale3 = f32((sc16_2 >> 8u) & 0xFFu);
610
+ let min2 = f32(sc16_3 & 0xFFu);
611
+ let min3 = f32((sc16_3 >> 8u) & 0xFFu);
612
+
// Packed nibbles start at byte 16; the second word is 64 bytes after the first.
613
+ let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset);
614
+ let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset);
615
+
// dot[k] accumulates nibble * src1 per run; sumx[k] the plain src1 sum per run.
616
+ var dot = vec4<f32>(0.0, 0.0, 0.0, 0.0);
617
+ var sumx = vec4<f32>(0.0, 0.0, 0.0, 0.0);
618
+ for (var i = 0u; i < 4u; i++) {
619
+ let q1b = byte_of(q1_u32, i);
620
+ let q2b = byte_of(q2_u32, i);
621
+ dot[0] += x_block[i] * f32(q1b & 0x0Fu);
622
+ dot[1] += x_block[i + 4u] * f32(q1b >> 4u);
623
+ dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu);
624
+ dot[3] += x_block[i + 12u] * f32(q2b >> 4u);
625
+ sumx[0] += x_block[i];
626
+ sumx[1] += x_block[i + 4u];
627
+ sumx[2] += x_block[i + 8u];
628
+ sumx[3] += x_block[i + 12u];
629
+ }
630
+
631
+ acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3)
632
+ - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3);
633
+ }
634
+ }
635
+ }
636
+
637
+ return acc;
638
+ }
639
+ #endif
640
+
641
+ #ifdef MUL_ACC_Q5_K
642
+ #define BLOCK_SIZE 256
643
+ #define BLOCK_SIZE_BYTES 176
644
+ #define THREADS_PER_BLOCK 16
// Q5_K variant: per-thread partial dot products between src0 rows (176-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
645
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
646
+ var acc: array<f32, OUTPUTS_PER_WG>;
647
+
648
+ let tid = thread_id % THREADS_PER_BLOCK;
649
+ let block_group = thread_id / THREADS_PER_BLOCK;
650
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
651
+
// Split tid (0..15) into the indices used for the offsets below.
652
+ let il = tid / 4u;
653
+ let ir = tid % 4u;
654
+ let im = il / 2u;
655
+ let in = il % 2u;
656
+ let l0 = 4u * (2u * ir + in);
657
+
// Byte offsets inside the block: q_offset for packed nibbles (from byte 48),
// qh_offset for the extra high bits (from byte 16), sc*_byte for three 16-bit
// scale words (from byte 4). y_offset indexes src1.
658
+ let y_offset = 64u * im + l0;
659
+ let q_offset = 48u + 32u * im + l0;
660
+ let qh_offset = 16u + 8u * ir + 4u * in;
661
+ let sc0_byte = 4u + im * 2u;
662
+ let sc2_byte = 4u + (im + 2u) * 2u;
663
+ let sc4_byte = 4u + (im + 4u) * 2u;
664
+
// Single-bit masks into each high-bit byte: bits 2*im, 2*im + 1, 2*im + 4, 2*im + 5.
665
+ let hm1 = 1u << (2u * im);
666
+ let hm2 = hm1 << 1u;
667
+ let hm3 = hm1 << 4u;
668
+ let hm4 = hm2 << 4u;
669
+
670
+ let num_blocks = params.k / BLOCK_SIZE;
671
+
672
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// Four runs of 4 src1 values at +0, +32, +128 and +160, loaded once per block.
673
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
674
+ var x_block: array<f32, 16>;
675
+ for (var i = 0u; i < 4u; i++) {
676
+ x_block[i] = f32(src1[x_base + i]);
677
+ x_block[i + 4u] = f32(src1[x_base + 32u + i]);
678
+ x_block[i + 8u] = f32(src1[x_base + 128u + i]);
679
+ x_block[i + 12u] = f32(src1[x_base + 160u + i]);
680
+ }
681
+
682
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
683
+ let output_row = row_base + row;
684
+ if (output_row < params.m) {
685
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
686
+
// f16 values at bytes 0 and 2: d scales the quantized sums, dmin the src1 sums.
687
+ let d = f32(load_f16_at_src0(block_byte_base + 0u));
688
+ let dmin = f32(load_f16_at_src0(block_byte_base + 2u));
689
+
// Take the 16-bit half of each loaded word: the upper half when bit 1 of the
// byte offset is set, else the lower half.
// NOTE(review): the half is chosen from sc*_byte only, so block_byte_base is
// presumably expected to be 4-byte aligned here - confirm.
690
+ let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
691
+ let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
692
+ let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
693
+ let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
694
+ let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
695
+ let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
696
+
// Two 6-bit values per 16-bit word: sc16_0/sc16_1 keep the low 6 bits of each byte;
// sc16_2/sc16_3 combine a nibble of sc4 with the top 2 bits of sc0/sc2.
697
+ let sc16_0 = sc0 & 0x3F3Fu;
698
+ let sc16_1 = sc2 & 0x3F3Fu;
699
+ let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
700
+ let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
701
+
// f* multiply the quantized sums, m* multiply the src1 sums in the final expression.
702
+ let f0 = f32(sc16_0 & 0xFFu);
703
+ let f1 = f32((sc16_0 >> 8u) & 0xFFu);
704
+ let m0 = f32(sc16_1 & 0xFFu);
705
+ let m1 = f32((sc16_1 >> 8u) & 0xFFu);
706
+ let f4 = f32(sc16_2 & 0xFFu);
707
+ let f5 = f32((sc16_2 >> 8u) & 0xFFu);
708
+ let m4 = f32(sc16_3 & 0xFFu);
709
+ let m5 = f32((sc16_3 >> 8u) & 0xFFu);
710
+
// Two nibble words 64 bytes apart, plus one word of high bits.
711
+ let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset);
712
+ let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u);
713
+ let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset);
714
+
715
+ var vals = vec4<f32>(0.0, 0.0, 0.0, 0.0);
716
+ var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
717
+ for (var i = 0u; i < 4u; i++) {
718
+ let q1b = byte_of(q1_u32, i);
719
+ let q2b = byte_of(q2_u32, i);
720
+ let qhb = byte_of(qh_u32, i);
721
+
722
+ let yl0 = x_block[i];
723
+ let yl8 = x_block[i + 4u];
724
+ let yh0 = x_block[i + 8u];
725
+ let yh8 = x_block[i + 12u];
726
+
727
+ sumy[0] += yl0;
728
+ sumy[1] += yl8;
729
+ sumy[2] += yh0;
730
+ sumy[3] += yh8;
731
+
// 5-bit value = nibble, plus 0x10 when the matching high-bit mask bit is set.
732
+ let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u));
733
+ let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u));
734
+ let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u));
735
+ let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u));
736
+
737
+ vals[0] += yl0 * q0;
738
+ vals[1] += yl8 * q1;
739
+ vals[2] += yh0 * q2;
740
+ vals[3] += yh8 * q3;
741
+ }
742
+
743
+ acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3])
744
+ - dmin * (sumy[0] * m0 + sumy[1] * m1 +
745
+ sumy[2] * m4 + sumy[3] * m5);
746
+ }
747
+ }
748
+ }
749
+
750
+ return acc;
751
+ }
752
+ #endif
753
+
754
+ #ifdef MUL_ACC_Q6_K
755
+ #define BLOCK_SIZE 256
756
+ #define BLOCK_SIZE_BYTES 210
757
+ #define THREADS_PER_BLOCK 16
// Q6_K variant: per-thread partial dot products between src0 rows (210-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
758
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
759
+ var acc: array<f32, OUTPUTS_PER_WG>;
760
+
761
+ let tid = thread_id % THREADS_PER_BLOCK;
762
+ let block_group = thread_id / THREADS_PER_BLOCK;
763
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
764
+
// ip selects one half of the block (0 or 1); il (0..7) the 4-element run inside it.
// is is the index of the first scale byte this thread uses.
765
+ let ip = tid / 8u;
766
+ let il = tid % 8u;
767
+ let l0 = 4u * il;
768
+ let is = 8u * ip + l0 / 16u;
769
+
770
+ let y_offset = 128u * ip + l0;
771
+ let q_offset_l = 64u * ip + l0;
772
+ let q_offset_h = 32u * ip + l0;
773
+
// Scale bytes start at byte 192; round the index down to a multiple of 4 and keep
// the byte position inside that word separately.
774
+ let num_blocks = params.k / BLOCK_SIZE;
775
+ let sc_base_byte = 192u + (is & ~3u);
776
+ let sc_byte_pos = is & 3u;
777
+
778
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// Four runs of 4 src1 values, 32 elements apart, loaded once per block.
779
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
780
+ var x_block: array<f32, 16>;
781
+ for (var l = 0u; l < 4u; l++) {
782
+ x_block[l] = f32(src1[x_base + l]);
783
+ x_block[l + 4u] = f32(src1[x_base + 32u + l]);
784
+ x_block[l + 8u] = f32(src1[x_base + 64u + l]);
785
+ x_block[l + 12u] = f32(src1[x_base + 96u + l]);
786
+ }
787
+
788
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
789
+ let output_row = row_base + row;
790
+ if (output_row < params.m) {
791
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
792
+
// f16 scale at byte 208; low nibbles from byte 0, high 2-bit fields from byte 128,
// scale bytes from byte 192.
793
+ let d = f32(load_f16_at_src0(block_byte_base + 208u));
794
+ let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l);
795
+ let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u);
796
+ let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h);
797
+ let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte);
798
+ let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u);
799
+
// Scale bytes at is, is + 2, is + 4 and is + 6.
// sbyte_of is defined elsewhere; presumably it extracts a signed byte - confirm.
800
+ let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
801
+ let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
802
+ let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
803
+ let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
804
+
805
+ var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
806
+
807
+ for (var l = 0u; l < 4u; l++) {
808
+ let q1b = byte_of(ql1_u32, l);
809
+ let q2b = byte_of(ql2_u32, l);
810
+ let qhb = byte_of(qh_u32, l);
811
+
// 6-bit value = 4 bits from a nibble + a 2-bit field of qhb moved to bits 4-5,
// then offset by -32.
812
+ let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
813
+ let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
814
+ let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
815
+ let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
816
+
817
+ sums[0] += x_block[l] * dq0;
818
+ sums[1] += x_block[l + 4u] * dq1;
819
+ sums[2] += x_block[l + 8u] * dq2;
820
+ sums[3] += x_block[l + 12u] * dq3;
821
+ }
822
+
823
+ acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
824
+ sums[2] * f32(sc4) + sums[3] * f32(sc6));
825
+ }
826
+ }
827
+ }
828
+
829
+ return acc;
830
+ }
831
+ #endif
832
+
833
+ #ifdef MUL_ACC_IQ1_S
834
+ #define BLOCK_SIZE 256
835
+ #define BLOCK_SIZE_BYTES 50
836
+ #define THREADS_PER_BLOCK 16
// IQ1_S variant: per-thread partial dot products between src0 rows (50-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
// iq1_grid and IQ1_DELTA are defined outside this block.
837
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
838
+ var acc: array<f32, OUTPUTS_PER_WG>;
839
+
840
+ let tid = thread_id % THREADS_PER_BLOCK;
841
+ let block_group = thread_id / THREADS_PER_BLOCK;
842
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
843
+
// Two threads per 32-element sub-block; each takes two 8-element slots
// (slot0 and slot0 + 1).
844
+ let sub_blk = tid / 2u;
845
+ let half = tid % 2u;
846
+ let slot0 = half * 2u;
847
+ let y_offset = sub_blk * 32u + slot0 * 8u;
848
+
849
+ let num_blocks = params.k / BLOCK_SIZE;
850
+
851
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// 16 consecutive src1 values, loaded once per block.
852
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
853
+ var x_block: array<f32, 16>;
854
+ for (var i = 0u; i < 16u; i++) {
855
+ x_block[i] = f32(src1[x_base + i]);
856
+ }
857
+
858
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
859
+ let output_row = row_base + row;
860
+ if (output_row < params.m) {
861
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
862
+
// f16 scale at byte 0; one 16-bit qh word per sub-block from byte 34; four index
// bytes per sub-block from byte 2. Bits 12-14 of qh give an odd multiplier for d;
// bit 15 selects the sign of delta.
863
+ let d = f32(load_f16_at_src0(block_byte_base));
864
+ let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
865
+ let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
866
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
867
+ let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
868
+
869
+ var row_sum = 0.0;
870
+ for (var ll = 0u; ll < 2u; ll++) {
// Grid index = index byte plus 3 bits of qh (bits 3l .. 3l + 2) above it.
// Each entry spans 8 two-bit codes, so 16 codes fit in one 32-bit grid word.
871
+ let l = slot0 + ll;
872
+ let qs_byte = get_byte(qs_w, l);
873
+ let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
874
+ let gw = iq1_grid[ig / 16u];
875
+ let bit_base = (ig % 16u) * 2u;
876
+ for (var j = 0u; j < 8u; j++) {
// Codes with bit 1 set map to g - 4 (2 -> -2, 3 -> -1); others stay as g.
877
+ let g = (gw >> (bit_base + j * 2u)) & 3u;
878
+ let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
879
+ row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
880
+ }
881
+ }
882
+ acc[row] += row_sum;
883
+ }
884
+ }
885
+ }
886
+
887
+ return acc;
888
+ }
889
+ #endif
890
+
891
+ #ifdef MUL_ACC_IQ1_M
892
+ #define BLOCK_SIZE 256
893
+ #define BLOCK_SIZE_BYTES 56
894
+ #define THREADS_PER_BLOCK 16
// IQ1_M variant: per-thread partial dot products between src0 rows (56-byte blocks of
// 256 weights) and src1. Sixteen threads share one block; each handles 16 weights.
//   thread_id         - picks the block group (thread_id / THREADS_PER_BLOCK) and the slice inside it.
//   row_base          - first output row handled; rows row_base .. row_base + OUTPUTS_PER_WG - 1.
//   src0_batch_offset - added to the block index before it is scaled to a byte offset.
//   src1_idx_base     - index of the first src1 element of the vector.
// Returns one partial sum per row; rows >= params.m are never accumulated into.
// iq1_grid and IQ1_DELTA are defined outside this block.
895
+ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
896
+ var acc: array<f32, OUTPUTS_PER_WG>;
897
+
898
+ let tid = thread_id % THREADS_PER_BLOCK;
899
+ let block_group = thread_id / THREADS_PER_BLOCK;
900
+ let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
901
+
// Two threads per 32-element sub-block; each takes two 8-element slots
// (slot0 and slot0 + 1).
902
+ let sub_blk = tid / 2u;
903
+ let half = tid % 2u;
904
+ let slot0 = half * 2u;
905
+ let y_offset = sub_blk * 32u + slot0 * 8u;
906
+
907
+ let num_blocks = params.k / BLOCK_SIZE;
908
+
909
+ for (var block = block_group; block < num_blocks; block += num_block_groups) {
// 16 consecutive src1 values, loaded once per block.
910
+ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
911
+ var x_block: array<f32, 16>;
912
+ for (var i = 0u; i < 16u; i++) {
913
+ x_block[i] = f32(src1[x_base + i]);
914
+ }
915
+
916
+ for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
917
+ let output_row = row_base + row;
918
+ if (output_row < params.m) {
919
+ let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
920
+
// Four 16-bit scale words at bytes 48..55. Their top nibbles are concatenated
// into a 16-bit pattern that is reinterpreted as the f16 block scale d.
921
+ let sc_lo = load_u32_at_src0(block_byte_base + 48u);
922
+ let sc_hi = load_u32_at_src0(block_byte_base + 52u);
923
+ let sc0 = sc_lo & 0xFFFFu;
924
+ let sc1 = (sc_lo >> 16u) & 0xFFFFu;
925
+ let sc2 = sc_hi & 0xFFFFu;
926
+ let sc3 = (sc_hi >> 16u) & 0xFFFFu;
927
+ let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
928
+ let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
929
+
// Nested select picks scale word sub_blk / 2: sc0 for 0-1, sc1 for 2-3,
// sc2 for 4-5, sc3 for 6-7.
930
+ let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
931
+ select(sc0, sc1, sub_blk >= 2u),
932
+ sub_blk < 4u);
933
+
// Four index bytes per sub-block from byte 0; two qh bytes per sub-block from byte 32.
934
+ let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
935
+ let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
936
+ let qh_lo = qh & 0xFFu;
937
+ let qh_hi = (qh >> 8u) & 0xFFu;
938
+
939
+ var row_sum = 0.0;
940
+ for (var ll = 0u; ll < 2u; ll++) {
// 3-bit sub-scale from the chosen scale word gives an odd multiplier for d.
// Each qh nibble holds 3 grid-index bits plus one bit choosing the sign of delta.
941
+ let l = slot0 + ll;
942
+ let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
943
+ let sub_scale = (sc_u16 >> bit_off) & 0x7u;
944
+ let dl = d * f32(2u * sub_scale + 1u);
945
+ let qh_byte = select(qh_lo, qh_hi, l >= 2u);
946
+ let ll2 = l % 2u;
947
+ let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
948
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
// Each entry spans 8 two-bit codes, so 16 codes fit in one 32-bit grid word.
949
+ let ig = grid_idx * 8u;
950
+ let gw = iq1_grid[ig / 16u];
951
+ let bit_base = (ig % 16u) * 2u;
952
+ for (var j = 0u; j < 8u; j++) {
// Codes with bit 1 set map to g - 4 (2 -> -2, 3 -> -1); others stay as g.
953
+ let g = (gw >> (bit_base + j * 2u)) & 3u;
954
+ let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
955
+ row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
956
+ }
957
+ }
958
+ acc[row] += row_sum;
959
+ }
960
+ }
961
+ }
962
+
963
+ return acc;
964
+ }
965
+ #endif
966
+
967
#ifdef MUL_ACC_IQ2_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 66
#define THREADS_PER_BLOCK 16
// IQ2_XXS variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block; each lane
// handles 16 consecutive values (two groups of 8) of one 32-value sub-block.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // which 16-value half of that sub-block
    let l_first = pair * 2u;     // first 8-value group index (0 or 2)
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // f16 block scale at byte 0, then 8 bytes per sub-block:
            // one word of four 8-bit grid indices, one word of signs + scale.
            let d = f32(load_f16_at_src0(base));
            let idx_word = load_u32_at_src0(base + 2u + ib32 * 8u);
            let aux_word = load_u32_at_src0(base + 2u + ib32 * 8u + 4u);
            // Top 4 bits of the second word hold the sub-block scale.
            let db = d * (0.5 + f32(aux_word >> 28u)) * 0.25;

            var partial = 0.0;
            for (var k = 0u; k < 2u; k += 1u) {
                let l = l_first + k;
                let grid_idx = (idx_word >> (8u * l)) & 0xFFu;
                // 7-bit sign selector; ksigns_iq2xs is read as four 8-bit masks per u32.
                let sign_idx = (aux_word >> (7u * l)) & 0x7Fu;
                let sign_bits = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;
                // Each grid entry spans two u32 words = eight 8-bit magnitudes.
                let grid_lo = iq2xxs_grid[grid_idx * 2u];
                let grid_hi = iq2xxs_grid[grid_idx * 2u + 1u];
                for (var j = 0u; j < 8u; j += 1u) {
                    let grid_word = select(grid_lo, grid_hi, j >= 4u);
                    let mag = f32((grid_word >> ((j % 4u) * 8u)) & 0xFFu);
                    let sgn = select(1.0, -1.0, ((sign_bits >> j) & 1u) != 0u);
                    partial += db * mag * sgn * xs[k * 8u + j];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1025
+
1026
#ifdef MUL_ACC_IQ2_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 74
#define THREADS_PER_BLOCK 16
// IQ2_XS variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block; each lane
// handles 16 consecutive values (two groups of 8) of one 32-value sub-block.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // which 16-value half of that sub-block
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // f16 block scale at byte 0; four 16-bit entries per sub-block from
            // byte 2; one scale byte per sub-block from byte 66.
            let d = f32(load_f16_at_src0(base));
            let qs_a = load_u32_at_src0(base + 2u + ib32 * 8u);
            let qs_b = load_u32_at_src0(base + 2u + ib32 * 8u + 4u);
            let scale_word = load_u32_at_src0(base + 66u + (ib32 / 4u) * 4u);
            let scale_byte = get_byte(scale_word, ib32 % 4u);

            // Both 16-bit entries of this lane sit in the same u32 and share
            // one 4-bit scale nibble.
            let lane_qs = select(qs_a, qs_b, pair == 1u);
            let nibble = (scale_byte >> (4u * pair)) & 0xFu;
            let db = d * (0.5 + f32(nibble)) * 0.25;

            var partial = 0.0;
            for (var k = 0u; k < 2u; k += 1u) {
                // Entry layout: bits 0..8 grid index, bits 9..15 sign selector.
                let entry = (lane_qs >> (16u * k)) & 0xFFFFu;
                let grid_idx = entry & 0x1FFu;
                let sign_idx = (entry >> 9u) & 0x7Fu;
                let sign_bits = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;
                // Each grid entry spans two u32 words = eight 8-bit magnitudes.
                let grid_lo = iq2xs_grid[grid_idx * 2u];
                let grid_hi = iq2xs_grid[grid_idx * 2u + 1u];
                for (var j = 0u; j < 8u; j += 1u) {
                    let grid_word = select(grid_lo, grid_hi, j >= 4u);
                    let mag = f32((grid_word >> ((j % 4u) * 8u)) & 0xFFu);
                    let sgn = select(1.0, -1.0, ((sign_bits >> j) & 1u) != 0u);
                    partial += db * mag * sgn * xs[k * 8u + j];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1089
+
1090
#ifdef MUL_ACC_IQ2_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 82
#define THREADS_PER_BLOCK 16
// IQ2_S variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block; each lane
// handles 16 consecutive values (two groups of 8) of one 32-value sub-block.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // which 16-value half of that sub-block
    let l_first = pair * 2u;     // first 8-value group index (0 or 2)
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // Byte offsets read here: 0 = f16 scale, 2 = low grid-index bytes,
            // 34 = sign bytes, 66 = high grid-index bits, 74 = 4-bit scales.
            let d = f32(load_f16_at_src0(base));
            let qs_word = load_u32_at_src0(base + 2u + ib32 * 4u);
            let sign_word = load_u32_at_src0(base + 34u + ib32 * 4u);
            let qh_word = load_u32_at_src0(base + 66u + (ib32 / 4u) * 4u);
            let qh_byte = get_byte(qh_word, ib32 % 4u);
            let scale_word = load_u32_at_src0(base + 74u + (ib32 / 4u) * 4u);
            let scale_byte = get_byte(scale_word, ib32 % 4u);

            // Both 8-value groups of this lane share one 4-bit scale nibble.
            let nibble = (scale_byte >> (4u * pair)) & 0xFu;
            let db = d * (0.5 + f32(nibble)) * 0.25;

            var partial = 0.0;
            for (var k = 0u; k < 2u; k += 1u) {
                let l = l_first + k;
                // 10-bit grid index: 8 low bits from qs, 2 high bits from qh.
                let high_bits = (qh_byte >> (2u * l)) & 3u;
                let grid_idx = get_byte(qs_word, l) | (high_bits << 8u);
                let sign_bits = get_byte(sign_word, l);
                // Each grid entry spans two u32 words = eight 8-bit magnitudes.
                let grid_lo = iq2s_grid[grid_idx * 2u];
                let grid_hi = iq2s_grid[grid_idx * 2u + 1u];
                for (var j = 0u; j < 8u; j += 1u) {
                    let grid_word = select(grid_lo, grid_hi, j >= 4u);
                    let mag = f32((grid_word >> ((j % 4u) * 8u)) & 0xFFu);
                    let sgn = select(1.0, -1.0, ((sign_bits >> j) & 1u) != 0u);
                    partial += db * mag * sgn * xs[k * 8u + j];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1152
+
1153
#ifdef MUL_ACC_IQ3_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 98
#define THREADS_PER_BLOCK 16
// IQ3_XXS variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block; each lane
// handles 16 consecutive values (two groups of 8) of one 32-value sub-block.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // which 16-value half of that sub-block
    let l_first = pair * 2u;     // first 8-value group index (0 or 2)
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // f16 block scale at byte 0; eight grid-index bytes per sub-block
            // from byte 2; one signs + scale word per sub-block from byte 66.
            let d = f32(load_f16_at_src0(base));
            let qs_a = load_u32_at_src0(base + 2u + ib32 * 8u);
            let qs_b = load_u32_at_src0(base + 2u + ib32 * 8u + 4u);
            let aux = load_u32_at_src0(base + 66u + ib32 * 4u);
            // Top 4 bits of aux hold the sub-block scale.
            let db = d * (0.5 + f32(aux >> 28u)) * 0.5;

            // All four index bytes used by this lane sit in the same u32.
            let lane_qs = select(qs_a, qs_b, pair == 1u);

            var partial = 0.0;
            for (var k = 0u; k < 2u; k += 1u) {
                let l = l_first + k;
                // Two consecutive index bytes select two 4-value grid entries.
                let idx_pair = (lane_qs >> (16u * k)) & 0xFFFFu;
                let grid_a = iq3xxs_grid[idx_pair & 0xFFu];
                let grid_b = iq3xxs_grid[idx_pair >> 8u];
                // 7-bit sign selector; ksigns_iq2xs is read as four 8-bit masks per u32.
                let sign_idx = (aux >> (7u * l)) & 0x7Fu;
                let sign_bits = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;
                for (var j = 0u; j < 4u; j += 1u) {
                    let mag_a = f32((grid_a >> (j * 8u)) & 0xFFu);
                    let mag_b = f32((grid_b >> (j * 8u)) & 0xFFu);
                    let sgn_a = select(1.0, -1.0, ((sign_bits >> j) & 1u) != 0u);
                    let sgn_b = select(1.0, -1.0, ((sign_bits >> (j + 4u)) & 1u) != 0u);
                    partial += db * mag_a * sgn_a * xs[k * 8u + j];
                    partial += db * mag_b * sgn_b * xs[k * 8u + j + 4u];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1217
+
1218
#ifdef MUL_ACC_IQ3_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
// IQ3_S variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block; each lane
// handles 16 consecutive values (two groups of 8) of one 32-value sub-block.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // which 16-value half of that sub-block
    let l_first = pair * 2u;     // first 8-value group index (0 or 2)
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // Byte offsets read here: 0 = f16 scale, 2 = low grid-index bytes,
            // 66 = high grid-index bits, 74 = sign bytes, 106 = 4-bit scales.
            let d = f32(load_f16_at_src0(base));
            let qs_a = load_u32_at_src0(base + 2u + ib32 * 8u);
            let qs_b = load_u32_at_src0(base + 2u + ib32 * 8u + 4u);
            let qh_word = load_u32_at_src0(base + 66u + (ib32 / 4u) * 4u);
            let qh_byte = get_byte(qh_word, ib32 % 4u);
            let sign_word = load_u32_at_src0(base + 74u + ib32 * 4u);
            let scale_word = load_u32_at_src0(base + 106u);

            // One scale byte serves two sub-blocks: low nibble for even, high for odd.
            let scale_byte = get_byte(scale_word, ib32 / 2u);
            let nibble = (scale_byte >> (4u * (ib32 % 2u))) & 0xFu;
            let db = d * (1.0 + 2.0 * f32(nibble));

            // All four index bytes used by this lane sit in the same u32.
            let lane_qs = select(qs_a, qs_b, pair == 1u);

            var partial = 0.0;
            for (var k = 0u; k < 2u; k += 1u) {
                let l = l_first + k;
                let idx_pair = (lane_qs >> (16u * k)) & 0xFFFFu;
                // qh supplies bit 8 of each index: bit 2l for the first, 2l + 1 for the second.
                let qh_pair = (qh_byte >> (2u * l)) & 3u;
                let grid_a = iq3s_grid[(idx_pair & 0xFFu) | ((qh_pair & 1u) << 8u)];
                let grid_b = iq3s_grid[(idx_pair >> 8u) | ((qh_pair >> 1u) << 8u)];
                let sign_bits = get_byte(sign_word, l);
                for (var j = 0u; j < 4u; j += 1u) {
                    let mag_a = f32((grid_a >> (j * 8u)) & 0xFFu);
                    let mag_b = f32((grid_b >> (j * 8u)) & 0xFFu);
                    let sgn_a = select(1.0, -1.0, ((sign_bits >> j) & 1u) != 0u);
                    let sgn_b = select(1.0, -1.0, ((sign_bits >> (j + 4u)) & 1u) != 0u);
                    partial += db * mag_a * sgn_a * xs[k * 8u + j];
                    partial += db * mag_b * sgn_b * xs[k * 8u + j + 4u];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1287
+
1288
#ifdef MUL_ACC_IQ4_NL
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// IQ4_NL variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 32-value block. Each lane reads
// four packed bytes: the low nibbles pair with src1 values [4*lane, 4*lane+4)
// and the high nibbles with the values 16 positions later.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // xs[0..3] match the low nibbles, xs[4..7] the high nibbles.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane * 4u;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var n = 0u; n < ELEMS_PER_THREAD / 2u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
            xs[n + 4u] = f32(src1[src1_start + n + 16u]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // f16 block scale at byte 0, packed 4-bit codes from byte 2.
            let d = f32(load_f16_at_src0(base));
            let quads = load_u32_at_src0(base + 2u + 4u * lane);

            var partial = 0.0;
            for (var b = 0u; b < 4u; b += 1u) {
                let code_pair = get_byte(quads, b);
                // Each nibble indexes the kvalues_iq4nl lookup table.
                let w_lo = f32(kvalues_iq4nl[code_pair & 0xFu]) * d;
                let w_hi = f32(kvalues_iq4nl[(code_pair >> 4u) & 0xFu]) * d;
                partial += w_lo * xs[b];
                partial += w_hi * xs[b + 4u];
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1329
+
1330
#ifdef MUL_ACC_IQ4_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 136
#define THREADS_PER_BLOCK 16
// IQ4_XS variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 256-value block. Two adjacent
// lanes share the same 16 packed bytes of a 32-value sub-block: the even lane
// consumes the low nibbles (values 0..15), the odd lane the high nibbles
// (values 16..31). Returns one partial sum per output row owned by the
// workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let ib32 = lane / 2u;        // 32-value sub-block handled by this lane
    let pair = lane % 2u;        // 0 = low nibbles, 1 = high nibbles
    let lane_offset = ib32 * 32u + pair * 16u;

    let block_step: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let block_count = params.k / BLOCK_SIZE;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < block_count; blk += block_step) {
        // Cache this lane's 16 src1 values once; they are reused for every row.
        let src1_start = src1_idx_base + blk * BLOCK_SIZE + lane_offset;
        var xs: array<f32, 16>;
        for (var n = 0u; n < 16u; n += 1u) {
            xs[n] = f32(src1[src1_start + n]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r += 1u) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            // Byte offsets read here: 0 = f16 scale, 2 = high scale bits (u16),
            // 4 = low scale nibbles, 8 = packed 4-bit codes.
            let d = f32(load_f16_at_src0(base));
            let scales_h = load_u16_at_src0(base + 2u);
            let scales_l = load_u32_at_src0(base + 4u);

            // 6-bit sub-block scale: 4 low bits + 2 high bits, stored with a bias of 32.
            let low4 = (get_byte(scales_l, ib32 / 2u) >> (4u * (ib32 % 2u))) & 0xFu;
            let high2 = (scales_h >> (2u * ib32)) & 3u;
            let dl = d * f32(i32(low4 | (high2 << 4u)) - 32);

            let qs_start = base + 8u + ib32 * 16u;

            var partial = 0.0;
            for (var w = 0u; w < 4u; w += 1u) {
                let q_word = load_u32_at_src0(qs_start + w * 4u);
                for (var b = 0u; b < 4u; b += 1u) {
                    let nib = (get_byte(q_word, b) >> (4u * pair)) & 0xFu;
                    partial += f32(kvalues_iq4nl[nib]) * dl * xs[w * 4u + b];
                }
            }
            acc[r] += partial;
        }
    }

    return acc;
}
#endif
1391
+
1392
#ifdef MUL_ACC_MXFP4
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 17
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// MXFP4 variant: per-thread partial dot products of src0 rows against src1.
//
// THREADS_PER_BLOCK threads cooperate on one 32-value block. Each lane reads
// four packed bytes: the low nibbles pair with src1 values [4*lane, 4*lane+4)
// and the high nibbles with the values 16 positions later.
// Returns one partial sum per output row owned by the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    // Lane index derived from the named constant so it cannot drift from the
    // block stride used in the loop header below.
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // x_block[0..3] match the low nibbles, x_block[4..7] the high nibbles.
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 4u] = f32(src1[x_base + i + 16u]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                // Byte 0 is an 8-bit exponent; the block scale is 2^(eu8 - 128).
                let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0u);
                let e = ldexp(1.0, i32(eu8) - 128);
                var row_sum = 0.0;
                // Packed 4-bit codes start at byte 1; each lane takes four bytes.
                let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block);
                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    // Each nibble indexes the kvalues_mxfp4 lookup table.
                    let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e;
                    let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + 4u];
                }
                acc[row] += row_sum;
            }
        }
    }

    return acc;
}
#endif