whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,556 @@
1
+ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
2
+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
3
+ #if LOAD_VEC_A == 8
4
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
5
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
6
+ FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
7
+ buf_a[buf_idx ] = aa[0].xy;
8
+ buf_a[buf_idx + 1] = aa[0].zw;
9
+ buf_a[buf_idx + 2] = aa[1].xy;
10
+ buf_a[buf_idx + 3] = aa[1].zw;
11
+ #elif LOAD_VEC_A == 4
12
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
13
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
14
+ FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
15
+ buf_a[buf_idx ] = aa.xy;
16
+ buf_a[buf_idx + 1] = aa.zw;
17
+ #else // LOAD_VEC_BATCH_A == 2
18
+ const uint idx = pos_a + col * p.stride_a + row * 2;
19
+ const uint buf_idx = col * SHMEM_STRIDE + row;
20
+ if (idx_m < p.M && block + row * 2 + 1 < end_k) {
21
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
22
+ data_a[idx + 1]);
23
+ } else if (idx_m < p.M && block + row * 2 < end_k) {
24
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
25
+ } else {
26
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
27
+ }
28
+ #endif
29
+ #elif defined(DATA_A_BF16)
30
+ #if LOAD_VEC_A == 4
31
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
32
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
33
+ FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
34
+ buf_a[buf_idx ] = aa.xy;
35
+ buf_a[buf_idx + 1] = aa.zw;
36
+ #else // LOAD_VEC_BATCH_A == 2
37
+ const uint idx = pos_a + col * p.stride_a + row * 2;
38
+ const uint buf_idx = col * SHMEM_STRIDE + row;
39
+ if (idx_m < p.M && block + row * 2 + 1 < end_k) {
40
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
41
+ TO_FLOAT_TYPE(data_a[idx + 1]));
42
+ } else if (idx_m < p.M && block + row * 2 < end_k) {
43
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
44
+ } else {
45
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
46
+ }
47
+ #endif
48
+ #elif defined(DATA_A_Q4_0)
49
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
50
+ const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
51
+
52
+ const uint ib = idx / 4;
53
+ const uint iqs = idx & 0x03;
54
+
55
+ const float d = float(data_a_packed16[ib].d);
56
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
57
+ const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
58
+ const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
59
+
60
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
61
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
62
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
63
+ buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
64
+ #elif defined(DATA_A_Q4_1)
65
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
66
+ const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
67
+
68
+ const uint ib = idx / 4;
69
+ const uint iqs = idx & 0x03;
70
+
71
+ const float d = float(data_a_packed16[ib].d);
72
+ const float m = float(data_a_packed16[ib].m);
73
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
74
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
75
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
76
+
77
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
78
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
79
+ buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
80
+ buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
81
+ #elif defined(DATA_A_Q5_0)
82
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
83
+ const uint buf_idx = col * SHMEM_STRIDE + row;
84
+
85
+ const uint ib = idx / 8;
86
+ const uint iqs = idx & 0x07;
87
+
88
+ const float d = float(data_a_packed16[ib].d);
89
+ const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
90
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
91
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
92
+
93
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
94
+ const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
95
+
96
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
97
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
98
+ #elif defined(DATA_A_Q5_1)
99
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
100
+ const uint buf_idx = col * SHMEM_STRIDE + row;
101
+
102
+ const uint ib = idx / 8;
103
+ const uint iqs = idx & 0x07;
104
+
105
+ const float d = float(data_a_packed16[ib].d);
106
+ const float m = float(data_a_packed16[ib].m);
107
+ const uint uint_qh = data_a_packed16[ib].qh;
108
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
109
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
110
+
111
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
112
+ const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
113
+
114
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
115
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
116
+ #elif defined(DATA_A_Q8_0)
117
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
118
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
119
+
120
+ const uint ib = idx / 8;
121
+ const uint iqs = idx & 0x07;
122
+
123
+ const float d = float(data_a_packed16[ib].d);
124
+ const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
125
+ const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
126
+ const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
127
+
128
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
129
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
130
+ #elif defined(DATA_A_Q2_K)
131
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
132
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
133
+
134
+ const uint ib = idx / 128; // 2 values per idx
135
+ const uint iqs = idx % 128; // 0..127
136
+
137
+ const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
138
+ const uint scalesi = iqs / 8; // 0..15
139
+ const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
140
+
141
+ const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
142
+ const uint scales = data_a[ib].scales[scalesi];
143
+ const vec2 d = vec2(data_a[ib].d);
144
+
145
+ const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
146
+
147
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
148
+ #elif defined(DATA_A_Q3_K)
149
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
150
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
151
+
152
+ const uint ib = idx / 128; // 2 values per idx
153
+ const uint iqs = idx % 128; // 0..127
154
+
155
+ const uint n = iqs / 64; // 0,1
156
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
157
+ const uint hmi = (iqs % 16) * 2; // 0,2,4..30
158
+ const uint j = (iqs % 64) / 4; // 0..3
159
+ const uint is = iqs / 8; // 0..15
160
+ const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
161
+ const uint qsshift = halfsplit * 2; // 0,2,4,6
162
+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
163
+
164
+ const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
165
+ | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
166
+ const float dl = float(data_a[ib].d) * float(us - 32);
167
+
168
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
169
+ dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
170
+ #elif defined(DATA_A_Q4_K)
171
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
172
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
173
+
174
+ const uint ib = idx / 128; // 2 values per idx
175
+ const uint iqs = idx % 128; // 0..127
176
+
177
+ const uint n = iqs / 32; // 0,1,2,3
178
+ const uint b = (iqs % 32) / 16; // 0,1
179
+ const uint is = 2 * n + b; // 0..7
180
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
181
+
182
+ const vec2 loadd = vec2(data_a[ib].d);
183
+
184
+ const uint scidx0 = (is < 4) ? is : (is + 4);
185
+ const uint scidx1 = (is < 4) ? is : (is - 4);
186
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
187
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
188
+ const uint mbidx0 = is + 4;
189
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
190
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
191
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
192
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
193
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
194
+
195
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
196
+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
197
+
198
+ const float d = loadd.x * sc;
199
+ const float m = -loadd.y * mbyte;
200
+
201
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
202
+ fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
203
+ #elif defined(DATA_A_Q5_K)
204
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
205
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
206
+
207
+ const uint ib = idx / 128; // 2 values per idx
208
+ const uint iqs = idx % 128; // 0..127
209
+
210
+ const uint n = iqs / 32; // 0,1,2,3
211
+ const uint b = (iqs % 32) / 16; // 0,1
212
+ const uint is = 2 * n + b; // 0..7
213
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
214
+ const uint qhi = (iqs % 16) * 2; // 0,2,4..30
215
+
216
+ const uint8_t hm = uint8_t(1 << (iqs / 16));
217
+
218
+ const vec2 loadd = vec2(data_a[ib].d);
219
+
220
+ const uint scidx0 = (is < 4) ? is : (is + 4);
221
+ const uint scidx1 = (is < 4) ? is : (is - 4);
222
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
223
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
224
+ const uint mbidx0 = is + 4;
225
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
226
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
227
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
228
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
229
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
230
+
231
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
232
+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
233
+
234
+ const float d = loadd.x * sc;
235
+ const float m = -loadd.y * mbyte;
236
+
237
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
238
+ fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
239
+ #elif defined(DATA_A_Q6_K)
240
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
241
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
242
+
243
+ const uint ib = idx / 128; // 2 values per idx
244
+ const uint iqs = idx % 128; // 0..127
245
+
246
+ const uint n = iqs / 64; // 0,1
247
+ const uint b = (iqs % 64) / 32; // 0,1
248
+ const uint is_b = (iqs % 16) / 8; // 0,1
249
+ const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
250
+ const uint is = 8 * n + qhshift + is_b; // 0..15
251
+ const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
252
+ const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
253
+
254
+ const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
255
+
256
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
257
+ dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
258
+ #elif defined(DATA_A_IQ1_S)
259
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
260
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
261
+
262
+ const uint ib = idx / 32; // 8 values per idx
263
+ const uint ib32 = (idx % 32) / 4; // 0..7
264
+ const uint ib8 = idx % 32;
265
+
266
+ const float d = float(data_a[ib].d);
267
+ const uint qh = data_a[ib].qh[ib32];
268
+ const uint qs = data_a[ib].qs[ib8];
269
+ const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
270
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
271
+ const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
272
+
273
+ [[unroll]] for (int k = 0; k < 4; ++k) {
274
+ buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
275
+ dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
276
+ }
277
+ #elif defined(DATA_A_IQ1_M)
278
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
279
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
280
+
281
+ const uint ib = idx / 32; // 8 values per idx
282
+ const uint ib8 = idx % 32;
283
+ const uint ib16 = ib8 / 2;
284
+
285
+ const uint16_t[4] scales = data_a[ib].scales;
286
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
287
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
288
+ const uint sc = scales[ib8 / 8];
289
+ const uint qs = data_a[ib].qs[ib8];
290
+ const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
291
+ const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
292
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
293
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
294
+
295
+ [[unroll]] for (int k = 0; k < 4; ++k) {
296
+ buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
297
+ dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
298
+ }
299
+ #elif defined(DATA_A_IQ2_XXS)
300
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
301
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
302
+
303
+ const uint ib = idx / 32; // 8 values per idx
304
+ const uint ib32 = (idx % 32) / 4; // 0..7
305
+ const uint ib8 = idx % 4;
306
+
307
+ const float d = float(data_a[ib].d);
308
+ const uint qs = data_a[ib].qs[8 * ib32 + ib8];
309
+ const uint signs = pack32(u8vec4(
310
+ data_a[ib].qs[8*ib32 + 4],
311
+ data_a[ib].qs[8*ib32 + 5],
312
+ data_a[ib].qs[8*ib32 + 6],
313
+ data_a[ib].qs[8*ib32 + 7]
314
+ ));
315
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
316
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
317
+ const uint sign = sign7 | (bitCount(sign7) << 7);
318
+ const uvec2 grid = iq2xxs_grid[qs];
319
+ const vec4 grid0 = vec4(unpack8(grid.x));
320
+ const vec4 grid1 = vec4(unpack8(grid.y));
321
+
322
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
323
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
324
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
325
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
326
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
327
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
328
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
329
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
330
+ #elif defined(DATA_A_IQ2_XS)
331
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
332
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
333
+
334
+ const uint ib = idx / 32; // 8 values per idx
335
+ const uint ib32 = (idx % 32) / 4; // 0..7
336
+ const uint ib8 = idx % 4; // 0..3
337
+
338
+ const float d = float(data_a[ib].d);
339
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
340
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
341
+ const uint qs = data_a[ib].qs[4 * ib32 + ib8];
342
+ const uint sign7 = qs >> 9;
343
+ const uint sign = sign7 | (bitCount(sign7) << 7);
344
+ const uvec2 grid = iq2xs_grid[qs & 511];
345
+ const vec4 grid0 = vec4(unpack8(grid.x));
346
+ const vec4 grid1 = vec4(unpack8(grid.y));
347
+
348
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
349
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
350
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
351
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
352
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
353
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
354
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
355
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
356
+ #elif defined(DATA_A_IQ2_S)
357
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
358
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
359
+
360
+ const uint ib = idx / 32; // 8 values per idx
361
+ const uint ib8 = idx % 32; // 0..31
362
+ const uint ib32 = ib8 / 4; // 0..7
363
+
364
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
365
+ const uint qs = data_a[ib].qs[ib8];
366
+ const uint qh = data_a[ib].qh[ib32];
367
+ const uint qhshift = 2 * (ib8 % 4);
368
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
369
+
370
+ const float d = float(data_a[ib].d);
371
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
372
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
373
+ const vec4 grid0 = vec4(unpack8(grid.x));
374
+ const vec4 grid1 = vec4(unpack8(grid.y));
375
+
376
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
377
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
378
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
379
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
380
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
381
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
382
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
383
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
384
+ #elif defined(DATA_A_IQ3_XXS)
385
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
386
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
387
+
388
+ const uint ib = idx / 64; // 4 values per idx
389
+ const uint iqs = idx % 64; // 0..63
390
+ const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
391
+
392
+ const float d = float(data_a[ib].d);
393
+ const uint qs = data_a[ib].qs[iqs];
394
+ const uint signs = pack32(u8vec4(
395
+ data_a[ib].qs[is+0],
396
+ data_a[ib].qs[is+1],
397
+ data_a[ib].qs[is+2],
398
+ data_a[ib].qs[is+3]
399
+ ));
400
+ const float db = d * 0.5 * (0.5 + (signs >> 28));
401
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
402
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
403
+ const uint grid = iq3xxs_grid[qs];
404
+ const vec4 v = db * vec4(unpack8(grid));
405
+
406
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
407
+ (sign & 2) != 0 ? -v.y : v.y);
408
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
409
+ (sign & 8) != 0 ? -v.w : v.w);
410
+ #elif defined(DATA_A_IQ3_S)
411
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
412
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
413
+
414
+ const uint ib = idx / 64; // 4 values per idx
415
+ const uint iqs = idx % 64; // 0..63
416
+ const uint iqh = iqs / 8;
417
+
418
+ const float d = float(data_a[ib].d);
419
+ const uint qs = data_a[ib].qs[iqs];
420
+ const uint qh = data_a[ib].qh[iqh];
421
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
422
+ const uint scale = data_a[ib].scales[iqs / 16];
423
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
424
+ const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
425
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
426
+ const vec4 v = db * vec4(unpack8(grid));
427
+
428
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
429
+ (sign & 2) != 0 ? -v.y : v.y);
430
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
431
+ (sign & 8) != 0 ? -v.w : v.w);
432
+ #elif defined(DATA_A_IQ4_XS)
433
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
434
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
435
+
436
+ const uint ib = idx / 128; // 2 values per idx
437
+ const uint ib32 = (idx % 128) / 16; // 0..7
438
+ const uint iq = 16 * ib32 + 2 * (idx % 8);
439
+
440
+ const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
441
+ const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
442
+ const uint qshift = (idx & 8) >> 1;
443
+ u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
444
+ qs = (qs >> qshift) & uint8_t(0xF);
445
+
446
+ const float d = float(data_a[ib].d);
447
+ const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
448
+
449
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
450
+ #elif defined(DATA_A_IQ4_NL)
451
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
452
+ const uint buf_idx = col * SHMEM_STRIDE + row;
453
+
454
+ const uint ib = idx / 8;
455
+ const uint iqs = idx & 0x07;
456
+
457
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
458
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
459
+
460
+ buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
461
+ kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
462
+ buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
463
+ kvalues_iq4nl[vui >> 12]);
464
+ #elif defined(DATA_A_MXFP4)
465
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
466
+ const uint buf_idx = col * SHMEM_STRIDE + row;
467
+
468
+ const uint ib = idx / 8;
469
+ const uint iqs = (idx & 0x07) * 2;
470
+
471
+ const float d = e8m0_to_fp32(data_a[ib].e);
472
+ const uint vui = uint(data_a[ib].qs[iqs]);
473
+ const uint vui2 = uint(data_a[ib].qs[iqs+1]);
474
+
475
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
476
+ kvalues_mxfp4[vui2 & 0xF] * d);
477
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
478
+ kvalues_mxfp4[vui2 >> 4] * d);
479
+ #endif
480
+ }
481
+
482
+ #if !defined(MUL_MAT_ID)
483
+ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
484
+ #if LOAD_VEC_B == 8
485
+ // Not supported for b_type bf16 because bf16mat2x4 does not exist
486
+ const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
487
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
488
+ FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
489
+ buf_b[buf_idx + 0] = bb[0].xy;
490
+ buf_b[buf_idx + 1] = bb[0].zw;
491
+ buf_b[buf_idx + 2] = bb[1].xy;
492
+ buf_b[buf_idx + 3] = bb[1].zw;
493
+ #elif LOAD_VEC_B == 4
494
+ const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
495
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
496
+ #if defined(DATA_B_BF16)
497
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
498
+ #else
499
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
500
+ #endif
501
+ buf_b[buf_idx + 0] = bb.xy;
502
+ buf_b[buf_idx + 1] = bb.zw;
503
+ #else // LOAD_VEC_BATCH_B == 2
504
+ const uint idx = pos_b + col * p.stride_b + row * 2;
505
+ const uint buf_idx = col * SHMEM_STRIDE + row;
506
+ if (idx_n < p.N && block + row * 2 + 1 < end_k) {
507
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
508
+ TO_FLOAT_TYPE(data_b[idx + 1]));
509
+ } else if (idx_n < p.N && block + row * 2 < end_k) {
510
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
511
+ } else {
512
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
513
+ }
514
+ #endif
515
+ }
516
+ #else
517
+ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
518
+ #if LOAD_VEC_B == 8
519
+ // Not supported for b_type bf16 because bf16mat2x4 does not exist
520
+ const u16vec2 row_idx = row_ids[col];
521
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
522
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
523
+ FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
524
+ buf_b[buf_idx + 0] = bb[0].xy;
525
+ buf_b[buf_idx + 1] = bb[0].zw;
526
+ buf_b[buf_idx + 2] = bb[1].xy;
527
+ buf_b[buf_idx + 3] = bb[1].zw;
528
+ #elif LOAD_VEC_B == 4
529
+ const u16vec2 row_idx = row_ids[col];
530
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
531
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
532
+ #if defined(DATA_B_BF16)
533
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
534
+ #else
535
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
536
+ #endif
537
+ buf_b[buf_idx + 0] = bb.xy;
538
+ buf_b[buf_idx + 1] = bb.zw;
539
+ #else // LOAD_VEC_BATCH_B == 2
540
+ const uint row_i = ic * BN + col;
541
+ const uint buf_idx = col * SHMEM_STRIDE + row;
542
+ if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
543
+ const u16vec2 row_idx = row_ids[col];
544
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
545
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
546
+ TO_FLOAT_TYPE(data_b[idx + 1]));
547
+ } else if (row_i < _ne1 && block + row * 2 < end_k) {
548
+ const u16vec2 row_idx = row_ids[col];
549
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
550
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
551
+ } else {
552
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
553
+ }
554
+ #endif
555
+ }
556
+ #endif
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
28
28
  #if defined(A_TYPE_PACKED32)
29
29
  layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
30
30
  #endif
31
- layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31
+ layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
32
32
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
33
33
 
34
34
  #ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
98
98
  #endif
99
99
 
100
100
  #define LOAD_VEC_A (4 * QUANT_R)
101
- #define LOAD_VEC_B 4
101
+ #define LOAD_VEC_B 16
102
102
 
103
103
  #ifdef MUL_MAT_ID
104
104
  shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270
270
  const uint iqs = idx & 0x7;
271
271
  #else
272
272
  const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273
+ const uint ib_outer = ib / 4;
274
+ const uint ib_inner = ib % 4;
275
+
273
276
  const uint iqs = loadr_b;
274
277
  #endif
275
278
 
276
279
  const uint buf_ib = loadc_b + l;
277
280
 
278
281
  if (iqs == 0) {
279
- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282
+ buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280
283
  }
281
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284
+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282
289
  }
283
290
 
284
291
  barrier();
@@ -349,7 +356,7 @@ void main() {
349
356
  cache_b_qs[cc * (BK / 4) + idx_k]);
350
357
  }
351
358
 
352
- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
359
+ sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
353
360
  }
354
361
  }
355
362
  }
@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
16
16
  (vui >> 4) & 0x0F0F0F0F);
17
17
  }
18
18
 
19
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
20
- return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
19
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
20
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
21
21
  }
22
22
  #endif
23
23
 
@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
29
29
  (vui >> 4) & 0x0F0F0F0F);
30
30
  }
31
31
 
32
- ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
33
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
32
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
33
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
34
34
  }
35
35
  #endif
36
36
 
@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
50
50
  return i32vec2(v0, v1);
51
51
  }
52
52
 
53
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
54
- return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
53
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
54
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
55
55
  }
56
56
  #endif
57
57
 
@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
69
69
  return i32vec2(v0, v1);
70
70
  }
71
71
 
72
- ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
73
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
72
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
73
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
74
74
  }
75
75
  #endif
76
76
 
@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
81
81
  data_a[ib].qs[iqs * 2 + 1]));
82
82
  }
83
83
 
84
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
84
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
85
85
  return ACC_TYPE(float(q_sum) * da * dsb.x);
86
86
  }
87
87
  #endif
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
92
92
  }
93
93
  #endif
94
94
 
95
+ #if defined(DATA_A_MXFP4)
96
+ FLOAT_TYPE get_d(uint ib) {
97
+ return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
98
+ }
99
+ #endif
100
+
95
101
  #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
96
102
  FLOAT_TYPE_VEC2 get_dm(uint ib) {
97
103
  return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);