whispercpp 1.3.6 → 1.3.7

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) {
25
25
  }
26
26
  #endif // SCALAR
27
27
 
28
+ #define QUANT_SHMEM shmem
29
+ #define QUANT_OUT_TYPE f16
30
+ #include "quant_inner_loops.tmpl"
31
+
28
32
  #ifdef INIT_SRC0_SHMEM_FLOAT
29
33
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
30
34
  for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
@@ -42,6 +46,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
42
46
  }
43
47
  #endif // INIT_SRC0_SHMEM_FLOAT
44
48
 
49
+ #ifndef MUL_MAT_ID
45
50
  #ifdef INIT_SRC1_SHMEM_FLOAT
46
51
  fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
47
52
  for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
@@ -58,307 +63,530 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
58
63
  }
59
64
  }
60
65
  #endif // INIT_SRC1_SHMEM_FLOAT
66
+ #endif
61
67
 
62
- #ifdef INIT_SRC0_SHMEM_Q4_0
63
- const BLOCK_SIZE = 32u;
64
- // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
65
- override BLOCKS_K = TILE_K/BLOCK_SIZE;
66
- const NQ = 16u;
67
- const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
68
- const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
69
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
68
+ #ifdef INIT_SRC0_SHMEM_Q1_0
69
+ const BLOCK_SIZE = 128u;
70
+ const BLOCK_SIZE_BYTES = 18u;
71
+ const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration
70
72
 
71
73
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
72
74
  for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
73
- let blck_idx = i / BLOCK_SIZE;
74
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
75
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
76
-
77
- let tile_m = blck_idx / BLOCKS_K;
75
+ let tile_m = i / TILE_K;
76
+ let tile_k_start = i % TILE_K;
78
77
  let global_m = offset_m + tile_m;
79
- let block_k = blck_idx % BLOCKS_K;
80
- let global_k = k_outer / BLOCK_SIZE + block_k;
78
+ let global_k_start = k_outer + tile_k_start;
81
79
 
82
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
83
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
84
- let scale_idx = src0_idx * F16_PER_BLOCK;
85
- let d = src0[scale_idx];
86
-
87
- for (var j = 0u; j < F16_PER_THREAD; j += 2) {
88
- let q_0 = src0[scale_idx + 1u + block_offset + j];
89
- let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
80
+ if (global_m >= params.m) {
81
+ break;
82
+ }
90
83
 
91
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
92
- for (var k = 0u; k < 4u; k++) {
93
- let q_byte = get_byte(q_packed, k);
94
- let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
95
- let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
96
- shmem[shmem_idx + j * 2 + k] = q_lo;
97
- shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
98
- }
84
+ let block_k = global_k_start / BLOCK_SIZE;
85
+ let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
86
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
87
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
88
+ let d = load_f16_at_src0(block_byte_base);
89
+ let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;
90
+
91
+ for (var bit = 0u; bit < NQ; bit++) {
92
+ let global_k = global_k_start + bit;
93
+ if (global_k < params.k) {
94
+ shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
99
95
  }
100
96
  }
101
97
  }
102
98
  }
103
- #endif // INIT_SRC0_SHMEM_Q4_0
99
+ #endif // INIT_SRC0_SHMEM_Q1_0
104
100
 
105
- #ifdef INIT_SRC0_SHMEM_Q4_1
101
+ #if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
106
102
  const BLOCK_SIZE = 32u;
107
103
  // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
108
104
  override BLOCKS_K = TILE_K/BLOCK_SIZE;
109
105
  const NQ = 16u;
110
- const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
111
- const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
112
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
106
+ #if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1)
107
+ const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
108
+ #else
109
+ const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
110
+ #endif
111
+ const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
113
112
 
114
113
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
115
114
  for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
116
- let blck_idx = i / BLOCK_SIZE;
117
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
118
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
115
+ let block_idx = i / BLOCK_SIZE;
116
+ let block_offset = (i % BLOCK_SIZE) / NQ;
117
+ let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
119
118
 
120
- let tile_m = blck_idx / BLOCKS_K;
119
+ let tile_m = block_idx / BLOCKS_K;
121
120
  let global_m = offset_m + tile_m;
122
- let block_k = blck_idx % BLOCKS_K;
123
- let global_k = k_outer / BLOCK_SIZE + block_k;
124
-
125
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
126
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
127
- let scale_idx = src0_idx * F16_PER_BLOCK;
128
- let d = src0[scale_idx];
129
- let m = src0[scale_idx + 1u];
121
+ let block_k = block_idx % BLOCKS_K;
122
+ let global_block_k = k_outer / BLOCK_SIZE + block_k;
130
123
 
131
- for (var j = 0u; j < F16_PER_THREAD; j += 2) {
132
- let q_0 = src0[scale_idx + 2u + block_offset + j];
133
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
124
+ if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
125
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
134
126
 
135
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
136
- for (var k = 0u; k < 4u; k++) {
127
+ #ifdef INIT_SRC0_SHMEM_Q4_0
128
+ let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
129
+ let d = load_f16_at_src0(block_byte_base);
130
+
131
+ // load NQ(16) weights
132
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
133
+ let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
134
+ let q_packed = load_u32_at_src0(q_byte_offset);
135
+ dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
136
+ }
137
+ #elif INIT_SRC0_SHMEM_Q4_1
138
+ let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
139
+ let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
140
+ let d = f16(dm[0]);
141
+ let m = f16(dm[1]);
142
+
143
+ // load NQ(16) weights
144
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
145
+ let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
146
+ let q_packed = load_u32_at_src0(q_byte_offset);
147
+
148
+ for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
137
149
  let q_byte = get_byte(q_packed, k);
138
150
  let q_lo = f16(q_byte & 0xF) * d + m;
139
151
  let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
140
- shmem[shmem_idx + j * 2 + k] = q_lo;
141
- shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
152
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
153
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
142
154
  }
143
155
  }
144
- }
145
- }
146
- }
147
- #endif // INIT_SRC0_SHMEM_Q4_1
156
+ #elif INIT_SRC0_SHMEM_Q5_0
157
+ let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
148
158
 
149
- #ifdef INIT_SRC0_SHMEM_Q5_0
150
- // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
151
- const BLOCK_SIZE = 32u;
152
- // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
153
- // tile_k is defined as 32u, so blocks_k ends up being 1 always
154
- override BLOCKS_K = TILE_K / BLOCK_SIZE;
155
- const NQ = 16u;
156
- const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
157
- const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
158
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
159
-
160
- fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
159
+ let d = load_f16_at_src0(block_byte_base);
160
+ let qh_packed = load_u32_at_src0(block_byte_base + 2u);
161
161
 
162
- for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
163
- let blck_idx = i / BLOCK_SIZE;
164
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
165
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
166
-
167
- let tile_m = blck_idx / BLOCKS_K;
168
- let global_m = offset_m + tile_m;
169
- let block_k = blck_idx % BLOCKS_K;
170
- let global_k = k_outer / BLOCK_SIZE + block_k;
162
+ // load NQ(16) weights
163
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
164
+ let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
165
+ let q_packed = load_u32_at_src0(q_byte_offset);
171
166
 
172
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
173
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
174
- let scale_idx = src0_idx * F16_PER_BLOCK;
175
-
176
- let d = src0[scale_idx];
177
- let qh0 = src0[scale_idx + 1u];
178
- let qh1 = src0[scale_idx + 2u];
179
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
180
-
181
- for (var j = 0u; j < 2; j++) {
182
- let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
183
- let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
167
+ for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
168
+ let q_byte = get_byte(q_packed, k);
184
169
 
185
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
170
+ let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k;
171
+ let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10;
172
+ let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
173
+ let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10;
174
+ let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
175
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
176
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
177
+ }
178
+ }
179
+ #elif INIT_SRC0_SHMEM_Q5_1
180
+ let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
186
181
 
187
- let j_adjusted = j + (block_offset / 2u);
182
+ let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
183
+ let d = f16(dm[0]);
184
+ let m = f16(dm[1]);
185
+ let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u);
188
186
 
187
+ // load NQ(16) weights
188
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
189
+ let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
190
+ let q_packed = load_u32_at_src0_aligned(q_byte_offset);
189
191
 
190
- for (var k = 0u; k < 4u; k++) {
192
+ for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
191
193
  let q_byte = get_byte(q_packed, k);
192
194
 
193
- let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
194
- let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
195
- let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
196
- let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
197
-
198
- shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
199
- shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
195
+ let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k;
196
+ let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10;
197
+ let q_hi = f16(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
198
+ let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10;
199
+ let q_lo = f16((q_byte & 0xF) | qh_lo) * d + m;
200
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
201
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
200
202
  }
201
203
  }
204
+ #elif INIT_SRC0_SHMEM_Q8_0
205
+ let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
206
+ let d = load_f16_at_src0(block_byte_base);
207
+
208
+ // load NQ(16) weights
209
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
210
+ let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
211
+ let q_packed = load_u32_at_src0(q_byte_offset);
212
+ dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
213
+ }
214
+ #elif INIT_SRC0_SHMEM_Q8_1
215
+ let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
216
+ let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
217
+ let d = f16(dm[0]);
218
+ let m = f16(dm[1]);
219
+
220
+ // load NQ(16) weights
221
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
222
+ let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
223
+ let q_packed = load_u32_at_src0(q_byte_offset);
224
+ for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
225
+ let q_byte = get_byte_i32(q_packed, k);
226
+ let q_val = f16(q_byte) * d + m;
227
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
228
+ }
229
+ }
230
+ #elif INIT_SRC0_SHMEM_MXFP4
231
+ let block_byte_base = src0_idx * 17u;
232
+ let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
233
+ let e = ldexp(1.0, i32(eu8) - 128);
234
+
235
+ // load NQ(16) weights
236
+ for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
237
+ let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
238
+ let q_packed = load_u32_at_src0(q_byte_offset);
239
+ for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
240
+ let q_byte = get_byte(q_packed, k);
241
+ let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
242
+ let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
243
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
244
+ shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
245
+ }
246
+ }
247
+ #endif
202
248
  }
203
249
  }
204
250
  }
205
- #endif // INIT_SRC0_SHMEM_Q5_0
251
+ #endif
206
252
 
207
- #ifdef INIT_SRC0_SHMEM_Q5_1
208
- // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
209
- const BLOCK_SIZE = 32u;
210
- // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
211
- // tile_k is defined as 32u, so blocks_k ends up being 1 always
212
- override BLOCKS_K = TILE_K / BLOCK_SIZE;
213
- const NQ = 16u;
214
- const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
215
- const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
216
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
253
+ // k-quants
254
+ #if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
255
+ const BLOCK_SIZE = 256u;
256
+ const NQ = 4u;
217
257
 
218
- fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
258
+ fn store_shmem_kquants(val: vec4<f16>, idx: u32) {
259
+ shmem[idx] = val.x;
260
+ shmem[idx + 1] = val.y;
261
+ shmem[idx + 2] = val.z;
262
+ shmem[idx + 3] = val.w;
263
+ }
219
264
 
220
- for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
221
- let blck_idx = i / BLOCK_SIZE;
222
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
223
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
265
+ fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
266
+ return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u);
267
+ }
268
+
269
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
270
+ for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) {
271
+ let tile_m = elem_idx / TILE_K;
272
+ let tile_k = elem_idx % TILE_K;
224
273
 
225
- let tile_m = blck_idx / BLOCKS_K;
226
274
  let global_m = offset_m + tile_m;
227
- let block_k = blck_idx % BLOCKS_K;
228
- let global_k = k_outer / BLOCK_SIZE + block_k;
275
+ let global_k = k_outer + tile_k;
229
276
 
230
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
231
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
232
- let scale_idx = src0_idx * F16_PER_BLOCK;
277
+ if (global_m >= params.m || global_k >= params.k) {
278
+ store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx);
279
+ continue;
280
+ }
233
281
 
234
- let d = src0[scale_idx];
235
- let m = src0[scale_idx + 1u];
236
- let qh0 = src0[scale_idx + 2u];
237
- let qh1 = src0[scale_idx + 3u];
238
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
282
+ let block_k = global_k / BLOCK_SIZE;
283
+ let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
239
284
 
240
- for (var j = 0u; j < 2; j++) {
285
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
241
286
 
242
- let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
243
- let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
287
+ #ifdef INIT_SRC0_SHMEM_Q2_K
288
+ let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
289
+ let scales_byte_base = block_byte_base;
290
+ let qs_byte_base = block_byte_base + 16u;
291
+ let dm_byte_base = block_byte_base + 80u;
292
+
293
+ let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
294
+ let d = f16(d_packed[0]);
295
+ let dmin = f16(d_packed[1]);
296
+
297
+ let chunk = k_in_block / 128u;
298
+ let pos_in_chunk = k_in_block % 32u;
299
+ let sub_block = k_in_block / 16u;
300
+ let shift_phase = (k_in_block % 128u) / 32u;
301
+
302
+ // whole 2 bits (4 elems)
303
+ let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
304
+ let qs_vec4 = vec4<f16>(
305
+ f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u),
306
+ f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u),
307
+ f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u),
308
+ f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u),
309
+ );
310
+
311
+ let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block);
312
+
313
+ let dl = d * f16(scale & 0xFu);
314
+ let ml = dmin * f16(scale >> 4u);
315
+
316
+ store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
317
+ #elif INIT_SRC0_SHMEM_Q3_K
318
+ let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
319
+ let hmask_byte_base = block_byte_base + 0u;
320
+ let qs_byte_base = block_byte_base + 32u;
321
+ let scales_byte_base = block_byte_base + 96u;
322
+
323
+ let d_all = load_f16_at_src0(block_byte_base + 108u);
324
+
325
+ let chunk = k_in_block / 128u;
326
+ let pos_in_chunk = k_in_block % 32u;
327
+ let sub_block = k_in_block / 16u;
328
+ let shift_phase = (k_in_block % 128u) / 32u;
329
+
330
+ let hmask_block = pos_in_chunk;
331
+ let hmask_shift_phase = k_in_block / 32u;
332
+
333
+ // low 2 bits (4 elems)
334
+ let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block);
335
+ let q_lo2_vec4 = vec4<f16>(
336
+ f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u),
337
+ f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u),
338
+ f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u),
339
+ f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u)
340
+ );
341
+
342
+ // high 1 bit (4 elems)
343
+ let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk);
344
+ let q_hi1_vec4 = vec4<f16>(
345
+ f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)),
346
+ f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)),
347
+ f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)),
348
+ f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u))
349
+ );
350
+
351
+ let q_vec4 = q_lo2_vec4 - q_hi1_vec4;
352
+
353
+ let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu;
354
+ let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u;
355
+ let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
356
+
357
+ store_shmem_kquants(dl * q_vec4, elem_idx);
358
+ #elif INIT_SRC0_SHMEM_Q4_K
359
+ let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
360
+ let dm_byte_base = block_byte_base + 0u;
361
+ let scale_byte_base = block_byte_base + 4u;
362
+ let qs_byte_base = block_byte_base + 16u;
363
+
364
+ let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
365
+ let d = f16(dm[0]);
366
+ let dmin = f16(dm[1]);
367
+
368
+ let chunk = k_in_block / 64u;
369
+ let pos_in_chunk = (k_in_block % 64u) % 32u;
370
+ let sub_block = k_in_block / 32u;
371
+ let shift_phase = sub_block & 1u;
372
+
373
+ // whole 4 bits (4 elems)
374
+ let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
375
+ let qs_vec4 = vec4<f16>(
376
+ f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
377
+ f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
378
+ f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
379
+ f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
380
+ );
244
381
 
245
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
382
+ var sc: u32;
383
+ var mn: u32;
246
384
 
247
- let j_adjusted = j + (block_offset / 2u);
385
+ if (sub_block < 4u) {
386
+ let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
387
+ let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
388
+ sc = sc_byte & 63u;
389
+ mn = min_byte & 63u;
390
+ } else {
391
+ let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
392
+ let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
393
+ let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
394
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
395
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
396
+ }
248
397
 
398
+ let dl = d * f16(sc);
399
+ let ml = dmin * f16(mn);
249
400
 
250
- for (var k = 0u; k < 4u; k++) {
251
- let q_byte = get_byte(q_packed, k);
401
+ store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
402
+ #elif INIT_SRC0_SHMEM_Q5_K
403
+ let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
404
+ let dm_byte_base = block_byte_base + 0u;
405
+ let scale_byte_base = block_byte_base + 4u;
406
+ let qh_byte_base = block_byte_base + 16u;
407
+ let qs_byte_base = block_byte_base + 48u;
408
+
409
+ let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
410
+ let d = f16(dm[0]);
411
+ let dmin = f16(dm[1]);
412
+
413
+ let chunk = k_in_block / 64u;
414
+ let pos_in_chunk = (k_in_block % 64u) % 32u;
415
+ let sub_block = k_in_block / 32u;
416
+ let shift_phase = sub_block & 1u;
417
+
418
+ let qh_block = k_in_block % 32u;
419
+ let qh_shift_phase = sub_block;
420
+
421
+ // low 4 bits (4 elems)
422
+ let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
423
+ let qs_lo4_vec4 = vec4<f16>(
424
+ f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
425
+ f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
426
+ f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
427
+ f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
428
+ );
429
+
430
+ // high 1 bit (4 elems)
431
+ let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block);
432
+ let qh_vec4 = vec4<f16>(
433
+ f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)),
434
+ f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)),
435
+ f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)),
436
+ f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u))
437
+ );
252
438
 
253
- let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
254
- let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
255
- let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
256
- let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
439
+ var sc: u32;
440
+ var mn: u32;
257
441
 
258
- shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
259
- shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
260
- }
261
- }
442
+ if (sub_block < 4u) {
443
+ let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
444
+ let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
445
+ sc = sc_byte & 63u;
446
+ mn = min_byte & 63u;
447
+ } else {
448
+ let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
449
+ let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
450
+ let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
451
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
452
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
262
453
  }
454
+
455
+ let dl = d * f16(sc);
456
+ let ml = dmin * f16(mn);
457
+
458
+ store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
459
+ #elif INIT_SRC0_SHMEM_Q6_K
460
+ let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
461
+ let ql_byte_base = block_byte_base;
462
+ let qh_byte_base = block_byte_base + 128u;
463
+ let scales_byte_base = block_byte_base + 192u;
464
+ let d_byte_base = block_byte_base + 208u;
465
+
466
+ let d = load_f16_at_src0(d_byte_base);
467
+
468
+ let chunk = k_in_block / 128u;
469
+ let ql_pos_in_chunk = (k_in_block % 128u) % 64u;
470
+ let qh_pos_in_chunk = (k_in_block % 128u) % 32u;
471
+ let sub_block = k_in_block / 16u;
472
+ let ql_shift_phase = (k_in_block % 128u) / 64u;
473
+ let qh_shift_phase = (k_in_block % 128u) / 32u;
474
+
475
+ // low 4 bits (4 elems)
476
+ let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk);
477
+ let ql_lo4_vec4 = vec4<u32>(
478
+ (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu,
479
+ (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu,
480
+ (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu,
481
+ (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu
482
+ );
483
+
484
+ // hi 2 bits (4 elems)
485
+ let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk);
486
+ let qh_hi2_vec4 = vec4<u32>(
487
+ ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u,
488
+ ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u,
489
+ ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u,
490
+ ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u,
491
+ );
492
+
493
+ let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0);
494
+
495
+ let scale_byte = scales_byte_base + 1u * sub_block;
496
+ let scale_word = load_u32_at_src0_aligned(scale_byte);
497
+ let scale = get_byte_i32(scale_word, scale_byte & 3u);
498
+
499
+ store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
500
+ #endif
263
501
  }
264
502
  }
265
- #endif // INIT_SRC0_SHMEM_Q5_1
503
+ #endif // k-quants
266
504
 
267
- #ifdef INIT_SRC0_SHMEM_Q8_0
505
+ #ifdef INIT_SRC0_SHMEM_IQ4_NL
268
506
  const BLOCK_SIZE = 32u;
269
- // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
270
- override BLOCKS_K = TILE_K/BLOCK_SIZE;
271
- const NQ = 16u;
272
- const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
273
- const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
274
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
507
+ const BLOCK_SIZE_BYTES = 18u;
275
508
 
276
509
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
277
- for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
278
- let blck_idx = i / BLOCK_SIZE;
279
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
280
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
281
-
282
- let tile_m = blck_idx / BLOCKS_K;
510
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
511
+ let tile_m = elem_idx / TILE_K;
512
+ let tile_k = elem_idx % TILE_K;
283
513
  let global_m = offset_m + tile_m;
284
- let block_k = blck_idx % BLOCKS_K;
285
- let global_k = k_outer / BLOCK_SIZE + block_k;
514
+ let global_k = k_outer + tile_k;
515
+
516
+ if (global_m >= params.m || global_k >= params.k) {
517
+ shmem[elem_idx] = f16(0.0);
518
+ continue;
519
+ }
286
520
 
287
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
288
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
289
- let scale_idx = src0_idx * F16_PER_BLOCK;
290
- let d = src0[scale_idx];
521
+ let block_k = global_k / BLOCK_SIZE;
522
+ let k_in_block = global_k % BLOCK_SIZE;
291
523
 
292
- for (var j = 0u; j < F16_PER_THREAD; j+=2) {
293
- let q_0 = src0[scale_idx + 1u + block_offset + j];
294
- let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
524
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
525
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
526
+ let d = load_f16_at_src0(block_byte_base);
295
527
 
296
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
297
- for (var k = 0u; k < 4u; k++) {
298
- let q_byte = get_byte_i32(q_packed, k);
528
+ let pos = k_in_block % 16u;
529
+ let nib_shift = (k_in_block / 16u) * 4u;
530
+ let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u);
531
+ let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu;
299
532
 
300
- let q_val = f16(q_byte) * d;
301
- shmem[shmem_idx + j * 2 + k] = q_val;
302
- }
303
- }
304
- }
533
+ shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]);
305
534
  }
306
535
  }
307
- #endif // INIT_SRC0_SHMEM_Q8_0
536
+ #endif // INIT_SRC0_SHMEM_IQ4_NL
308
537
 
309
- #ifdef INIT_SRC0_SHMEM_Q8_1
310
- const BLOCK_SIZE = 32u;
311
- // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
312
- override BLOCKS_K = TILE_K/BLOCK_SIZE;
313
- const NQ = 16u;
314
- const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
315
- const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
316
- const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
538
+ #ifdef INIT_SRC0_SHMEM_IQ4_XS
539
+ const BLOCK_SIZE = 256u;
540
+ const BLOCK_SIZE_BYTES = 136u;
317
541
 
318
542
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
319
- for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
320
- let blck_idx = i / BLOCK_SIZE;
321
- let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
322
- let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
323
-
324
- let tile_m = blck_idx / BLOCKS_K;
543
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
544
+ let tile_m = elem_idx / TILE_K;
545
+ let tile_k = elem_idx % TILE_K;
325
546
  let global_m = offset_m + tile_m;
326
- let block_k = blck_idx % BLOCKS_K;
327
- let global_k = k_outer / BLOCK_SIZE + block_k;
547
+ let global_k = k_outer + tile_k;
548
+
549
+ if (global_m >= params.m || global_k >= params.k) {
550
+ shmem[elem_idx] = f16(0.0);
551
+ continue;
552
+ }
328
553
 
329
- if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
330
- let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
331
- let scale_idx = src0_idx * F16_PER_BLOCK;
332
- let d = src0[scale_idx];
333
- let m = src0[scale_idx + 1u];
554
+ let block_k = global_k / BLOCK_SIZE;
555
+ let k_in_block = global_k % BLOCK_SIZE;
334
556
 
335
- for (var j = 0u; j < F16_PER_THREAD; j+=2) {
336
- let q_0 = src0[scale_idx + 2u + block_offset + j];
337
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
557
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
558
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
338
559
 
339
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
340
- for (var k = 0u; k < 4u; k++) {
341
- let q_byte = get_byte_i32(q_packed, k);
560
+ let d_scales_h = load_u32_at_src0(block_byte_base);
561
+ let d = bitcast<vec2<f16>>(d_scales_h).x;
562
+ let scales_h = d_scales_h >> 16u;
342
563
 
343
- let q_val = f16(q_byte) * d + m;
344
- shmem[shmem_idx + j * 2 + k] = q_val;
345
- }
346
- }
347
- }
564
+ let ib = k_in_block / 32u;
565
+ let pos = k_in_block % 32u;
566
+
567
+ let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
568
+ let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu;
569
+ let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u;
570
+ let dl = d * f16(i32(ls_lo | ls_hi) - 32);
571
+
572
+ let iqs = ib * 16u + (pos % 16u);
573
+ let nib_shift = (pos / 16u) * 4u;
574
+ let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u);
575
+ let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu;
576
+
577
+ shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]);
348
578
  }
349
579
  }
350
- #endif // INIT_SRC0_SHMEM_Q8_1
580
+ #endif // INIT_SRC0_SHMEM_IQ4_XS
351
581
 
352
- #ifdef INIT_SRC0_SHMEM_Q2_K
582
+ #ifdef INIT_SRC0_SHMEM_IQ1_S
353
583
  const BLOCK_SIZE = 256u;
354
- const F16_PER_BLOCK = 42u;
584
+ const BLOCK_SIZE_BYTES = 50u;
355
585
 
356
586
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
357
- // Use standard thread layout instead of lane/row_group
358
587
  for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
359
588
  let tile_m = elem_idx / TILE_K;
360
589
  let tile_k = elem_idx % TILE_K;
361
-
362
590
  let global_m = offset_m + tile_m;
363
591
  let global_k = k_outer + tile_k;
364
592
 
@@ -367,56 +595,42 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
367
595
  continue;
368
596
  }
369
597
 
370
- let block_k = global_k / BLOCK_SIZE;
598
+ let block_k = global_k / BLOCK_SIZE;
371
599
  let k_in_block = global_k % BLOCK_SIZE;
372
600
 
373
- let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
374
- let scale_idx = src0_idx * F16_PER_BLOCK;
375
-
376
- let d = src0[scale_idx + 40u];
377
- let dmin = src0[scale_idx + 41u];
378
-
379
- // Decode the element at position k_in_block
380
- let block_of_32 = k_in_block / 32u;
381
- let pos_in_32 = k_in_block % 32u;
382
-
383
- let q_b_idx = (block_of_32 / 4u) * 32u;
384
- let shift = (block_of_32 % 4u) * 2u;
385
- let k = (pos_in_32 / 16u) * 16u;
386
- let l = pos_in_32 % 16u;
601
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
602
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
603
+ let d = load_f16_as_f32_at_src0(block_byte_base);
387
604
 
388
- let is = k_in_block / 16u;
605
+ let ib = k_in_block / 32u;
606
+ let pos = k_in_block % 32u;
607
+ let l = pos / 8u;
608
+ let j = pos % 8u;
389
609
 
390
- let sc_0 = src0[scale_idx + 2u * (is / 4u)];
391
- let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
392
- let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
393
- let sc = get_byte(sc_packed, is % 4u);
610
+ let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu;
611
+ let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
612
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
394
613
 
395
- let dl = d * f16(sc & 0xFu);
396
- let ml = dmin * f16(sc >> 4u);
614
+ let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
615
+ let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
397
616
 
398
- let q_idx = q_b_idx + k + l;
399
- let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
400
- let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
401
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
402
- let q_byte = get_byte(q_packed, q_idx % 4u);
403
- let qs_val = (q_byte >> shift) & 3u;
617
+ let gw = iq1_grid[(ig + j) / 16u];
618
+ let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
619
+ let gs = bitcast<i32>(g << 30u) >> 30u;
404
620
 
405
- let q_val = f16(qs_val) * dl - ml;
406
- shmem[elem_idx] = q_val;
621
+ shmem[elem_idx] = f16(dl * (f32(gs) + delta));
407
622
  }
408
623
  }
409
- #endif // INIT_SRC0_SHMEM_Q2_K
624
+ #endif // INIT_SRC0_SHMEM_IQ1_S
410
625
 
411
- #ifdef INIT_SRC0_SHMEM_Q3_K
626
+ #ifdef INIT_SRC0_SHMEM_IQ1_M
412
627
  const BLOCK_SIZE = 256u;
413
- const F16_PER_BLOCK = 55u;
628
+ const BLOCK_SIZE_BYTES = 56u;
414
629
 
415
630
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
416
631
  for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
417
632
  let tile_m = elem_idx / TILE_K;
418
633
  let tile_k = elem_idx % TILE_K;
419
-
420
634
  let global_m = offset_m + tile_m;
421
635
  let global_k = k_outer + tile_k;
422
636
 
@@ -425,90 +639,101 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
425
639
  continue;
426
640
  }
427
641
 
428
- let block_k = global_k / BLOCK_SIZE;
642
+ let block_k = global_k / BLOCK_SIZE;
429
643
  let k_in_block = global_k % BLOCK_SIZE;
430
644
 
431
- let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
432
- let scale_idx = src0_idx * F16_PER_BLOCK;
433
-
434
- let d = src0[scale_idx + 54u];
435
-
436
- // Load and unpack scales
437
- let kmask1: u32 = 0x03030303u;
438
- let kmask2: u32 = 0x0f0f0f0fu;
645
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
646
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
647
+
648
+ let scales0 = load_u32_at_src0(block_byte_base + 48u);
649
+ let scales1 = load_u32_at_src0(block_byte_base + 52u);
650
+ let scale_packed = ((scales0 >> 12u) & 0xFu) |
651
+ ((scales0 >> 24u) & 0x00F0u) |
652
+ ((scales1 >> 4u) & 0x0F00u) |
653
+ ((scales1 >> 16u) & 0xF000u);
654
+ let d = f32(bitcast<vec2<f16>>(scale_packed).x);
655
+
656
+ let ib = k_in_block / 32u;
657
+ let pos = k_in_block % 32u;
658
+ let l = pos / 8u;
659
+ let j = pos % 8u;
660
+
661
+ let scales = select(scales0, scales1, ib >= 4u);
662
+ let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu;
663
+ let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u;
664
+ let dl = d * f32(2u * s_pair + 1u);
665
+
666
+ let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u);
667
+ let qh = qh_word >> (16u * (ib % 2u));
668
+ let qh_nib = (qh >> (4u * l)) & 0xFu;
669
+
670
+ let qs_w = load_u32_at_src0(block_byte_base + ib * 4u);
671
+ let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u);
672
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u);
673
+
674
+ let ig = idx * 8u;
675
+ let gw = iq1_grid[(ig + j) / 16u];
676
+ let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
677
+ let gs = bitcast<i32>(g << 30u) >> 30u;
678
+
679
+ shmem[elem_idx] = f16(dl * (f32(gs) + delta));
680
+ }
681
+ }
682
+ #endif // INIT_SRC0_SHMEM_IQ1_M
439
683
 
440
- var scale_vals: array<u32, 4>;
441
- for (var i: u32 = 0u; i < 4u; i++) {
442
- let scale_0 = src0[scale_idx + 48u + (2u*i)];
443
- let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
444
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
445
- }
684
+ #ifdef INIT_SRC0_SHMEM_IQ2_XXS
685
+ const BLOCK_SIZE = 256u;
686
+ const BLOCK_SIZE_BYTES = 66u;
446
687
 
447
- var tmp: u32 = scale_vals[2];
448
- scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
449
- scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
450
- scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
451
- scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
452
-
453
- // Load hmask and qs arrays
454
- var hmask_vals: array<u32, 8>;
455
- for (var i: u32 = 0u; i < 8u; i++) {
456
- let hmask_0 = src0[scale_idx + (2u*i)];
457
- let hmask_1 = src0[scale_idx + (2u*i) + 1u];
458
- hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
459
- }
688
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
689
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
690
+ let tile_m = elem_idx / TILE_K;
691
+ let tile_k = elem_idx % TILE_K;
692
+ let global_m = offset_m + tile_m;
693
+ let global_k = k_outer + tile_k;
460
694
 
461
- var qs_vals: array<u32, 16>;
462
- for (var i: u32 = 0u; i < 16u; i++) {
463
- let qs_0 = src0[scale_idx + 16u + (2u*i)];
464
- let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
465
- qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
695
+ if (global_m >= params.m || global_k >= params.k) {
696
+ shmem[elem_idx] = f16(0.0);
697
+ continue;
466
698
  }
467
699
 
468
- let half = k_in_block / 128u; // 0 or 1
469
- let pos_in_half = k_in_block % 128u; // 0-127
470
- let shift_group = pos_in_half / 32u; // 0-3
471
- let pos_in_32 = pos_in_half % 32u; // 0-31
472
- let k_group = pos_in_32 / 16u; // 0 or 1
473
- let l = pos_in_32 % 16u; // 0-15
700
+ let block_k = global_k / BLOCK_SIZE;
701
+ let k_in_block = global_k % BLOCK_SIZE;
474
702
 
475
- let q_b_idx = half * 32u; // 0 or 32
476
- let shift = shift_group * 2u; // 0, 2, 4, 6
477
- let k = k_group * 16u; // 0 or 16
478
- let is = k_in_block / 16u; // 0-15
703
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
704
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
705
+ let d = load_f16_as_f32_at_src0(block_byte_base);
479
706
 
480
- // m increments every 32 elements across entire 256 element block
481
- let m_shift = k_in_block / 32u; // 0-7
482
- let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
707
+ let entry_idx = k_in_block / 8u;
708
+ let j = k_in_block % 8u;
483
709
 
484
- let sc = get_byte(scale_vals[is / 4u], is % 4u);
485
- let dl = d * (f16(sc) - 32.0);
710
+ let ib = entry_idx & ~3u;
711
+ let l = entry_idx & 3u;
486
712
 
487
- let q_idx = q_b_idx + k + l;
488
- let hm_idx = k + l;
713
+ let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u);
714
+ let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u);
715
+ let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25;
489
716
 
490
- let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
491
- let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
717
+ let ig = get_byte(aux0, l) * 8u;
718
+ let is = (aux1 >> (7u * l)) & 127u;
719
+ let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
492
720
 
493
- let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
494
- let qs_val = (q_byte >> shift) & 3u;
721
+ let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u);
722
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
495
723
 
496
- let q_val = (f16(qs_val) - f16(hm)) * dl;
497
- shmem[elem_idx] = q_val;
724
+ shmem[elem_idx] = f16(db * f32(g) * m);
498
725
  }
499
726
  }
727
+ #endif // INIT_SRC0_SHMEM_IQ2_XXS
500
728
 
501
- #endif // INIT_SRC0_SHMEM_Q3_K
502
-
503
- #ifdef INIT_SRC0_SHMEM_Q4_K
729
+ #ifdef INIT_SRC0_SHMEM_IQ2_XS
504
730
  const BLOCK_SIZE = 256u;
505
- const F16_PER_BLOCK = 72u;
731
+ const BLOCK_SIZE_BYTES = 74u;
506
732
 
507
733
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
508
734
  for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
509
735
  let tile_m = elem_idx / TILE_K;
510
736
  let tile_k = elem_idx % TILE_K;
511
-
512
737
  let global_m = offset_m + tile_m;
513
738
  let global_k = k_outer + tile_k;
514
739
 
@@ -517,78 +742,46 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
517
742
  continue;
518
743
  }
519
744
 
520
- let block_k = global_k / BLOCK_SIZE;
745
+ let block_k = global_k / BLOCK_SIZE;
521
746
  let k_in_block = global_k % BLOCK_SIZE;
522
747
 
523
- let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
524
- let scale_idx = src0_idx * F16_PER_BLOCK;
525
-
526
- let d = src0[scale_idx];
527
- let dmin = src0[scale_idx + 1u];
528
-
529
- // Load packed scales
530
- var scale_vals: array<u32, 3>;
531
- for (var i: u32 = 0u; i < 3u; i++) {
532
- let scale_0 = src0[scale_idx + 2u + (2u*i)];
533
- let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
534
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
535
- }
748
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
749
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
750
+ let d = load_f16_as_f32_at_src0(block_byte_base);
536
751
 
537
- // Map k_in_block to loop structure:
538
- // Outer loop over 64-element groups (alternating q_b_idx)
539
- // Inner loop over 2 shifts per group
540
- let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
541
- let pos_in_64 = k_in_block % 64u; // 0-63
542
- let shift_group = pos_in_64 / 32u; // 0 or 1
543
- let l = pos_in_64 % 32u; // 0-31
752
+ let entry_idx = k_in_block / 8u;
753
+ let j = k_in_block % 8u;
544
754
 
545
- let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
546
- let shift = shift_group * 4u; // 0 or 4
547
- let is = k_in_block / 32u; // 0-7
548
-
549
- var sc: u32;
550
- var mn: u32;
755
+ let ib = entry_idx & ~3u;
756
+ let l = entry_idx & 3u;
551
757
 
552
- if (is < 4u) {
553
- let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
554
- let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
555
- sc = sc_byte & 63u;
556
- mn = min_byte & 63u;
557
- } else {
558
- let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
559
- let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
560
- let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
561
-
562
- sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
563
- mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
564
- }
565
-
566
- let dl = d * f16(sc);
567
- let ml = dmin * f16(mn);
758
+ let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u);
759
+ let s = get_byte(scales_word, (ib % 16u) / 4u);
760
+ let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
761
+ let dl = d * (0.5 + f32(s_nib)) * 0.25;
568
762
 
569
- let q_idx = q_b_idx + l;
570
- let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
571
- let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
572
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
763
+ let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u);
764
+ let qs_val = qs_word & 0xFFFFu;
765
+ let ig = (qs_val & 511u) * 8u;
766
+ let is = qs_val >> 9u;
767
+ let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
573
768
 
574
- let q_byte = get_byte(q_packed, q_idx % 4u);
575
- let qs_val = (q_byte >> shift) & 0xFu;
769
+ let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u);
770
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
576
771
 
577
- let q_val = f16(qs_val) * dl - ml;
578
- shmem[elem_idx] = q_val;
772
+ shmem[elem_idx] = f16(dl * f32(g) * m);
579
773
  }
580
774
  }
581
- #endif // INIT_SRC0_SHMEM_Q4_K
775
+ #endif // INIT_SRC0_SHMEM_IQ2_XS
582
776
 
583
- #ifdef INIT_SRC0_SHMEM_Q5_K
777
+ #ifdef INIT_SRC0_SHMEM_IQ2_S
584
778
  const BLOCK_SIZE = 256u;
585
- const F16_PER_BLOCK = 88u;
779
+ const BLOCK_SIZE_BYTES = 82u;
586
780
 
587
781
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
588
782
  for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
589
783
  let tile_m = elem_idx / TILE_K;
590
784
  let tile_k = elem_idx % TILE_K;
591
-
592
785
  let global_m = offset_m + tile_m;
593
786
  let global_k = k_outer + tile_k;
594
787
 
@@ -597,91 +790,93 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
597
790
  continue;
598
791
  }
599
792
 
600
- let block_k = global_k / BLOCK_SIZE;
793
+ let block_k = global_k / BLOCK_SIZE;
601
794
  let k_in_block = global_k % BLOCK_SIZE;
602
795
 
603
- let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
604
- let scale_idx = src0_idx * F16_PER_BLOCK;
605
-
606
- let d = src0[scale_idx];
607
- let dmin = src0[scale_idx + 1u];
796
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
797
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
798
+ let d = load_f16_as_f32_at_src0(block_byte_base);
608
799
 
609
- // Load packed scales
610
- var scale_vals: array<u32, 3>;
611
- for (var i: u32 = 0u; i < 3u; i++) {
612
- let scale_0 = src0[scale_idx + 2u + (2u*i)];
613
- let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
614
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
615
- }
800
+ let ib = k_in_block / 32u;
801
+ let l = (k_in_block % 32u) / 8u;
802
+ let j = k_in_block % 8u;
616
803
 
617
- // The original loop processes elements in groups of 64
618
- // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
619
- // But u increments EVERY 32 elements (after each l loop)
620
- let group_of_64 = k_in_block / 64u; // 0-3
621
- let pos_in_64 = k_in_block % 64u; // 0-63
622
- let shift_group = pos_in_64 / 32u; // 0 or 1
623
- let l = pos_in_64 % 32u; // 0-31
804
+ let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u);
805
+ let s = get_byte(scales_word, ib % 4u);
806
+ let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
807
+ let dl = d * (0.5 + f32(s_nib)) * 0.25;
624
808
 
625
- let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
626
- let shift = shift_group * 4u; // 0 or 4
627
- let is = k_in_block / 32u; // 0-7
809
+ let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
810
+ let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u);
811
+ let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u;
812
+ let ig = (get_byte(qs_word, l) | qh_b) * 8u;
628
813
 
629
- // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
630
- let u_shift = k_in_block / 32u; // 0-7
631
- let u: u32 = 1u << u_shift;
632
-
633
- var sc: u32;
634
- var mn: u32;
814
+ let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u);
815
+ let signs = get_byte(signs_word, l);
635
816
 
636
- if (is < 4u) {
637
- let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
638
- let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
639
- sc = sc_byte & 63u;
640
- mn = min_byte & 63u;
641
- } else {
642
- let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
643
- let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
644
- let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
645
-
646
- sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
647
- mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
648
- }
649
-
650
- let dl = d * f16(sc);
651
- let ml = dmin * f16(mn);
817
+ let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u);
818
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
652
819
 
653
- let q_idx = q_b_idx + l;
654
- let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
655
- let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
656
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
820
+ shmem[elem_idx] = f16(dl * f32(g) * m);
821
+ }
822
+ }
823
+ #endif // INIT_SRC0_SHMEM_IQ2_S
657
824
 
658
- let q_byte = get_byte(q_packed, q_idx % 4u);
825
+ #ifdef INIT_SRC0_SHMEM_IQ3_XXS
826
+ const BLOCK_SIZE = 256u;
827
+ const BLOCK_SIZE_BYTES = 98u;
659
828
 
660
- let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
661
- let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
662
- let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
829
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
830
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
831
+ let tile_m = elem_idx / TILE_K;
832
+ let tile_k = elem_idx % TILE_K;
833
+ let global_m = offset_m + tile_m;
834
+ let global_k = k_outer + tile_k;
663
835
 
664
- let qh_byte = get_byte(qh_packed, l % 4u);
836
+ if (global_m >= params.m || global_k >= params.k) {
837
+ shmem[elem_idx] = f16(0.0);
838
+ continue;
839
+ }
665
840
 
666
- let qs_val = (q_byte >> shift) & 0xFu;
667
- let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
841
+ let block_k = global_k / BLOCK_SIZE;
842
+ let k_in_block = global_k % BLOCK_SIZE;
668
843
 
669
- let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
670
- shmem[elem_idx] = q_val;
844
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
845
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
846
+ let d = load_f16_as_f32_at_src0(block_byte_base);
847
+
848
+ let ib_pair = k_in_block / 32u;
849
+ let in_pair = k_in_block % 32u;
850
+ let l = in_pair / 8u;
851
+ let in_l = in_pair % 8u;
852
+ let k2 = in_l / 4u;
853
+ let j = in_l % 4u;
854
+
855
+ let ib = ib_pair * 2u;
856
+ let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u;
857
+ let sc_sign = load_u32_at_src0(sc_sign_off);
858
+ let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5;
859
+ let is = (sc_sign >> (7u * l)) & 127u;
860
+ let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
861
+
862
+ let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu;
863
+ let ig_byte = get_byte(ig_word, k2);
864
+ let g = get_byte(iq3xxs_grid[ig_byte], j);
865
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
866
+
867
+ shmem[elem_idx] = f16(db * f32(g) * m);
671
868
  }
672
869
  }
870
+ #endif // INIT_SRC0_SHMEM_IQ3_XXS
673
871
 
674
- #endif // INIT_SRC0_SHMEM_Q5_K
675
-
676
- #ifdef INIT_SRC0_SHMEM_Q6_K
872
+ #ifdef INIT_SRC0_SHMEM_IQ3_S
677
873
  const BLOCK_SIZE = 256u;
678
- const F16_PER_BLOCK = 105u;
874
+ const BLOCK_SIZE_BYTES = 110u;
679
875
 
680
876
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
681
877
  for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
682
878
  let tile_m = elem_idx / TILE_K;
683
879
  let tile_k = elem_idx % TILE_K;
684
-
685
880
  let global_m = offset_m + tile_m;
686
881
  let global_k = k_outer + tile_k;
687
882
 
@@ -690,77 +885,42 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
690
885
  continue;
691
886
  }
692
887
 
693
- let block_k = global_k / BLOCK_SIZE;
888
+ let block_k = global_k / BLOCK_SIZE;
694
889
  let k_in_block = global_k % BLOCK_SIZE;
695
890
 
696
- let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
697
- let scale_idx = src0_idx * F16_PER_BLOCK;
698
-
699
- let half = k_in_block / 128u;
700
- let pos_in_half = k_in_block % 128u;
701
- let quarter = pos_in_half / 32u;
702
- let l = pos_in_half % 32u;
703
-
704
- let ql_b_idx = half * 64u;
705
- let qh_b_idx = half * 32u;
706
- let sc_b_idx = half * 8u;
707
-
708
- // Load only ql13 word needed
709
- let ql13_flat = ql_b_idx + l;
710
- let ql13_word = ql13_flat / 4u;
711
- let ql13 = bitcast<u32>(vec2(
712
- src0[scale_idx + 2u * ql13_word],
713
- src0[scale_idx + 2u * ql13_word + 1u]
714
- ));
715
- let ql13_b = get_byte(ql13, ql13_flat % 4u);
716
-
717
- // Load only ql24 word needed
718
- let ql24_flat = ql_b_idx + l + 32u;
719
- let ql24_word = ql24_flat / 4u;
720
- let ql24 = bitcast<u32>(vec2(
721
- src0[scale_idx + 2u * ql24_word],
722
- src0[scale_idx + 2u * ql24_word + 1u]
723
- ));
724
- let ql24_b = get_byte(ql24, ql24_flat % 4u);
725
-
726
- // Load only qh word needed
727
- let qh_flat = qh_b_idx + l;
728
- let qh_word = qh_flat / 4u;
729
- let qh = bitcast<u32>(vec2(
730
- src0[scale_idx + 64u + 2u * qh_word],
731
- src0[scale_idx + 64u + 2u * qh_word + 1u]
732
- ));
733
- let qh_b = get_byte(qh, qh_flat % 4u);
734
-
735
- let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
736
- let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
737
- let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
738
- let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
739
-
740
- // Load only the scale word needed
741
- let is = l / 16u;
742
- let sc_idx = sc_b_idx + is + quarter * 2u;
743
- let sc_word = sc_idx / 4u;
744
- let sc = bitcast<u32>(vec2(
745
- src0[scale_idx + 96u + 2u * sc_word],
746
- src0[scale_idx + 96u + 2u * sc_word + 1u]
747
- ));
748
- let sc_val = get_byte_i32(sc, sc_idx % 4u);
749
-
750
- let d = src0[scale_idx + 104u];
751
-
752
- var q_val: f16;
753
- if (quarter == 0u) {
754
- q_val = q1;
755
- } else if (quarter == 1u) {
756
- q_val = q2;
757
- } else if (quarter == 2u) {
758
- q_val = q3;
759
- } else {
760
- q_val = q4;
761
- }
891
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
892
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
893
+ let d = load_f16_as_f32_at_src0(block_byte_base);
894
+
895
+ let ib = k_in_block / 64u;
896
+ let rest = k_in_block % 64u;
897
+ let k = rest / 32u;
898
+ let in_k = rest % 32u;
899
+ let l = in_k / 8u;
900
+ let in_l = in_k % 8u;
901
+ let k2 = in_l / 4u;
902
+ let j = in_l % 4u;
903
+
904
+ let scales_word = load_u32_at_src0(block_byte_base + 106u);
905
+ let s = get_byte(scales_word, ib);
906
+ let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u);
907
+ let dl = d * (1.0 + 2.0 * f32(s_nib));
908
+
909
+ let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u);
910
+ let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k);
911
+
912
+ let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu;
913
+ let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u);
914
+ let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u);
915
+ let ig = select(ig_lo, ig_hi, k2 != 0u);
916
+
917
+ let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u);
918
+ let signs = get_byte(signs_word, l);
919
+
920
+ let g = get_byte(iq3s_grid[ig], j);
921
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
762
922
 
763
- shmem[elem_idx] = d * f16(sc_val) * q_val;
923
+ shmem[elem_idx] = f16(dl * f32(g) * m);
764
924
  }
765
925
  }
766
- #endif // INIT_SRC0_SHMEM_Q6_K
926
+ #endif // INIT_SRC0_SHMEM_IQ3_S