whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -0,0 +1,556 @@
1
+ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
2
+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
3
+ #if LOAD_VEC_A == 8
4
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
5
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
6
+ FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
7
+ buf_a[buf_idx ] = aa[0].xy;
8
+ buf_a[buf_idx + 1] = aa[0].zw;
9
+ buf_a[buf_idx + 2] = aa[1].xy;
10
+ buf_a[buf_idx + 3] = aa[1].zw;
11
+ #elif LOAD_VEC_A == 4
12
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
13
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
14
+ FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
15
+ buf_a[buf_idx ] = aa.xy;
16
+ buf_a[buf_idx + 1] = aa.zw;
17
+ #else // LOAD_VEC_BATCH_A == 2
18
+ const uint idx = pos_a + col * p.stride_a + row * 2;
19
+ const uint buf_idx = col * SHMEM_STRIDE + row;
20
+ if (idx_m < p.M && block + row * 2 + 1 < end_k) {
21
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
22
+ data_a[idx + 1]);
23
+ } else if (idx_m < p.M && block + row * 2 < end_k) {
24
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
25
+ } else {
26
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
27
+ }
28
+ #endif
29
+ #elif defined(DATA_A_BF16)
30
+ #if LOAD_VEC_A == 4
31
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
32
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
33
+ FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
34
+ buf_a[buf_idx ] = aa.xy;
35
+ buf_a[buf_idx + 1] = aa.zw;
36
+ #else // LOAD_VEC_BATCH_A == 2
37
+ const uint idx = pos_a + col * p.stride_a + row * 2;
38
+ const uint buf_idx = col * SHMEM_STRIDE + row;
39
+ if (idx_m < p.M && block + row * 2 + 1 < end_k) {
40
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
41
+ TO_FLOAT_TYPE(data_a[idx + 1]));
42
+ } else if (idx_m < p.M && block + row * 2 < end_k) {
43
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
44
+ } else {
45
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
46
+ }
47
+ #endif
48
+ #elif defined(DATA_A_Q4_0)
49
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
50
+ const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
51
+
52
+ const uint ib = idx / 4;
53
+ const uint iqs = idx & 0x03;
54
+
55
+ const float d = float(data_a_packed16[ib].d);
56
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
57
+ const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
58
+ const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
59
+
60
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
61
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
62
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
63
+ buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
64
+ #elif defined(DATA_A_Q4_1)
65
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
66
+ const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
67
+
68
+ const uint ib = idx / 4;
69
+ const uint iqs = idx & 0x03;
70
+
71
+ const float d = float(data_a_packed16[ib].d);
72
+ const float m = float(data_a_packed16[ib].m);
73
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
74
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
75
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
76
+
77
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
78
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
79
+ buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
80
+ buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
81
+ #elif defined(DATA_A_Q5_0)
82
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
83
+ const uint buf_idx = col * SHMEM_STRIDE + row;
84
+
85
+ const uint ib = idx / 8;
86
+ const uint iqs = idx & 0x07;
87
+
88
+ const float d = float(data_a_packed16[ib].d);
89
+ const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
90
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
91
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
92
+
93
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
94
+ const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
95
+
96
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
97
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
98
+ #elif defined(DATA_A_Q5_1)
99
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
100
+ const uint buf_idx = col * SHMEM_STRIDE + row;
101
+
102
+ const uint ib = idx / 8;
103
+ const uint iqs = idx & 0x07;
104
+
105
+ const float d = float(data_a_packed16[ib].d);
106
+ const float m = float(data_a_packed16[ib].m);
107
+ const uint uint_qh = data_a_packed16[ib].qh;
108
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
109
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
110
+
111
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
112
+ const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
113
+
114
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
115
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
116
+ #elif defined(DATA_A_Q8_0)
117
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
118
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
119
+
120
+ const uint ib = idx / 8;
121
+ const uint iqs = idx & 0x07;
122
+
123
+ const float d = float(data_a_packed16[ib].d);
124
+ const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
125
+ const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
126
+ const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
127
+
128
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
129
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
130
+ #elif defined(DATA_A_Q2_K)
131
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
132
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
133
+
134
+ const uint ib = idx / 128; // 2 values per idx
135
+ const uint iqs = idx % 128; // 0..127
136
+
137
+ const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
138
+ const uint scalesi = iqs / 8; // 0..15
139
+ const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
140
+
141
+ const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
142
+ const uint scales = data_a[ib].scales[scalesi];
143
+ const vec2 d = vec2(data_a[ib].d);
144
+
145
+ const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
146
+
147
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
148
+ #elif defined(DATA_A_Q3_K)
149
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
150
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
151
+
152
+ const uint ib = idx / 128; // 2 values per idx
153
+ const uint iqs = idx % 128; // 0..127
154
+
155
+ const uint n = iqs / 64; // 0,1
156
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
157
+ const uint hmi = (iqs % 16) * 2; // 0,2,4..30
158
+ const uint j = (iqs % 64) / 4; // 0..3
159
+ const uint is = iqs / 8; // 0..15
160
+ const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
161
+ const uint qsshift = halfsplit * 2; // 0,2,4,6
162
+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
163
+
164
+ const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
165
+ | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
166
+ const float dl = float(data_a[ib].d) * float(us - 32);
167
+
168
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
169
+ dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
170
+ #elif defined(DATA_A_Q4_K)
171
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
172
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
173
+
174
+ const uint ib = idx / 128; // 2 values per idx
175
+ const uint iqs = idx % 128; // 0..127
176
+
177
+ const uint n = iqs / 32; // 0,1,2,3
178
+ const uint b = (iqs % 32) / 16; // 0,1
179
+ const uint is = 2 * n + b; // 0..7
180
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
181
+
182
+ const vec2 loadd = vec2(data_a[ib].d);
183
+
184
+ const uint scidx0 = (is < 4) ? is : (is + 4);
185
+ const uint scidx1 = (is < 4) ? is : (is - 4);
186
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
187
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
188
+ const uint mbidx0 = is + 4;
189
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
190
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
191
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
192
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
193
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
194
+
195
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
196
+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
197
+
198
+ const float d = loadd.x * sc;
199
+ const float m = -loadd.y * mbyte;
200
+
201
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
202
+ fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
203
+ #elif defined(DATA_A_Q5_K)
204
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
205
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
206
+
207
+ const uint ib = idx / 128; // 2 values per idx
208
+ const uint iqs = idx % 128; // 0..127
209
+
210
+ const uint n = iqs / 32; // 0,1,2,3
211
+ const uint b = (iqs % 32) / 16; // 0,1
212
+ const uint is = 2 * n + b; // 0..7
213
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
214
+ const uint qhi = (iqs % 16) * 2; // 0,2,4..30
215
+
216
+ const uint8_t hm = uint8_t(1 << (iqs / 16));
217
+
218
+ const vec2 loadd = vec2(data_a[ib].d);
219
+
220
+ const uint scidx0 = (is < 4) ? is : (is + 4);
221
+ const uint scidx1 = (is < 4) ? is : (is - 4);
222
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
223
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
224
+ const uint mbidx0 = is + 4;
225
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
226
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
227
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
228
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
229
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
230
+
231
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
232
+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
233
+
234
+ const float d = loadd.x * sc;
235
+ const float m = -loadd.y * mbyte;
236
+
237
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
238
+ fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
239
+ #elif defined(DATA_A_Q6_K)
240
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
241
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
242
+
243
+ const uint ib = idx / 128; // 2 values per idx
244
+ const uint iqs = idx % 128; // 0..127
245
+
246
+ const uint n = iqs / 64; // 0,1
247
+ const uint b = (iqs % 64) / 32; // 0,1
248
+ const uint is_b = (iqs % 16) / 8; // 0,1
249
+ const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
250
+ const uint is = 8 * n + qhshift + is_b; // 0..15
251
+ const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
252
+ const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
253
+
254
+ const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
255
+
256
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
257
+ dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
258
+ #elif defined(DATA_A_IQ1_S)
259
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
260
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
261
+
262
+ const uint ib = idx / 32; // 8 values per idx
263
+ const uint ib32 = (idx % 32) / 4; // 0..7
264
+ const uint ib8 = idx % 32;
265
+
266
+ const float d = float(data_a[ib].d);
267
+ const uint qh = data_a[ib].qh[ib32];
268
+ const uint qs = data_a[ib].qs[ib8];
269
+ const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
270
+ const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
271
+ const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
272
+
273
+ [[unroll]] for (int k = 0; k < 4; ++k) {
274
+ buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
275
+ dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
276
+ }
277
+ #elif defined(DATA_A_IQ1_M)
278
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
279
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
280
+
281
+ const uint ib = idx / 32; // 8 values per idx
282
+ const uint ib8 = idx % 32;
283
+ const uint ib16 = ib8 / 2;
284
+
285
+ const uint16_t[4] scales = data_a[ib].scales;
286
+ const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
287
+ const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
288
+ const uint sc = scales[ib8 / 8];
289
+ const uint qs = data_a[ib].qs[ib8];
290
+ const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
291
+ const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
292
+ const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
293
+ const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
294
+
295
+ [[unroll]] for (int k = 0; k < 4; ++k) {
296
+ buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
297
+ dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
298
+ }
299
+ #elif defined(DATA_A_IQ2_XXS)
300
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
301
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
302
+
303
+ const uint ib = idx / 32; // 8 values per idx
304
+ const uint ib32 = (idx % 32) / 4; // 0..7
305
+ const uint ib8 = idx % 4;
306
+
307
+ const float d = float(data_a[ib].d);
308
+ const uint qs = data_a[ib].qs[8 * ib32 + ib8];
309
+ const uint signs = pack32(u8vec4(
310
+ data_a[ib].qs[8*ib32 + 4],
311
+ data_a[ib].qs[8*ib32 + 5],
312
+ data_a[ib].qs[8*ib32 + 6],
313
+ data_a[ib].qs[8*ib32 + 7]
314
+ ));
315
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
316
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
317
+ const uint sign = sign7 | (bitCount(sign7) << 7);
318
+ const uvec2 grid = iq2xxs_grid[qs];
319
+ const vec4 grid0 = vec4(unpack8(grid.x));
320
+ const vec4 grid1 = vec4(unpack8(grid.y));
321
+
322
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
323
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
324
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
325
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
326
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
327
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
328
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
329
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
330
+ #elif defined(DATA_A_IQ2_XS)
331
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
332
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
333
+
334
+ const uint ib = idx / 32; // 8 values per idx
335
+ const uint ib32 = (idx % 32) / 4; // 0..7
336
+ const uint ib8 = idx % 4; // 0..3
337
+
338
+ const float d = float(data_a[ib].d);
339
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
340
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
341
+ const uint qs = data_a[ib].qs[4 * ib32 + ib8];
342
+ const uint sign7 = qs >> 9;
343
+ const uint sign = sign7 | (bitCount(sign7) << 7);
344
+ const uvec2 grid = iq2xs_grid[qs & 511];
345
+ const vec4 grid0 = vec4(unpack8(grid.x));
346
+ const vec4 grid1 = vec4(unpack8(grid.y));
347
+
348
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
349
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
350
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
351
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
352
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
353
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
354
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
355
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
356
+ #elif defined(DATA_A_IQ2_S)
357
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
358
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
359
+
360
+ const uint ib = idx / 32; // 8 values per idx
361
+ const uint ib8 = idx % 32; // 0..31
362
+ const uint ib32 = ib8 / 4; // 0..7
363
+
364
+ const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
365
+ const uint qs = data_a[ib].qs[ib8];
366
+ const uint qh = data_a[ib].qh[ib32];
367
+ const uint qhshift = 2 * (ib8 % 4);
368
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
369
+
370
+ const float d = float(data_a[ib].d);
371
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
372
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
373
+ const vec4 grid0 = vec4(unpack8(grid.x));
374
+ const vec4 grid1 = vec4(unpack8(grid.y));
375
+
376
+ buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
377
+ (sign & 2) != 0 ? -grid0.y : grid0.y);
378
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
379
+ (sign & 8) != 0 ? -grid0.w : grid0.w);
380
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
381
+ (sign & 32) != 0 ? -grid1.y : grid1.y);
382
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
383
+ (sign & 128) != 0 ? -grid1.w : grid1.w);
384
+ #elif defined(DATA_A_IQ3_XXS)
385
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
386
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
387
+
388
+ const uint ib = idx / 64; // 4 values per idx
389
+ const uint iqs = idx % 64; // 0..63
390
+ const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
391
+
392
+ const float d = float(data_a[ib].d);
393
+ const uint qs = data_a[ib].qs[iqs];
394
+ const uint signs = pack32(u8vec4(
395
+ data_a[ib].qs[is+0],
396
+ data_a[ib].qs[is+1],
397
+ data_a[ib].qs[is+2],
398
+ data_a[ib].qs[is+3]
399
+ ));
400
+ const float db = d * 0.5 * (0.5 + (signs >> 28));
401
+ const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
402
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
403
+ const uint grid = iq3xxs_grid[qs];
404
+ const vec4 v = db * vec4(unpack8(grid));
405
+
406
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
407
+ (sign & 2) != 0 ? -v.y : v.y);
408
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
409
+ (sign & 8) != 0 ? -v.w : v.w);
410
+ #elif defined(DATA_A_IQ3_S)
411
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
412
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
413
+
414
+ const uint ib = idx / 64; // 4 values per idx
415
+ const uint iqs = idx % 64; // 0..63
416
+ const uint iqh = iqs / 8;
417
+
418
+ const float d = float(data_a[ib].d);
419
+ const uint qs = data_a[ib].qs[iqs];
420
+ const uint qh = data_a[ib].qh[iqh];
421
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
422
+ const uint scale = data_a[ib].scales[iqs / 16];
423
+ const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
424
+ const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
425
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
426
+ const vec4 v = db * vec4(unpack8(grid));
427
+
428
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
429
+ (sign & 2) != 0 ? -v.y : v.y);
430
+ buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
431
+ (sign & 8) != 0 ? -v.w : v.w);
432
+ #elif defined(DATA_A_IQ4_XS)
433
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
434
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
435
+
436
+ const uint ib = idx / 128; // 2 values per idx
437
+ const uint ib32 = (idx % 128) / 16; // 0..7
438
+ const uint iq = 16 * ib32 + 2 * (idx % 8);
439
+
440
+ const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
441
+ const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
442
+ const uint qshift = (idx & 8) >> 1;
443
+ u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
444
+ qs = (qs >> qshift) & uint8_t(0xF);
445
+
446
+ const float d = float(data_a[ib].d);
447
+ const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
448
+
449
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
450
+ #elif defined(DATA_A_IQ4_NL)
451
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
452
+ const uint buf_idx = col * SHMEM_STRIDE + row;
453
+
454
+ const uint ib = idx / 8;
455
+ const uint iqs = idx & 0x07;
456
+
457
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
458
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
459
+
460
+ buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
461
+ kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
462
+ buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
463
+ kvalues_iq4nl[vui >> 12]);
464
+ #elif defined(DATA_A_MXFP4)
465
+ const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
466
+ const uint buf_idx = col * SHMEM_STRIDE + row;
467
+
468
+ const uint ib = idx / 8;
469
+ const uint iqs = (idx & 0x07) * 2;
470
+
471
+ const float d = e8m0_to_fp32(data_a[ib].e);
472
+ const uint vui = uint(data_a[ib].qs[iqs]);
473
+ const uint vui2 = uint(data_a[ib].qs[iqs+1]);
474
+
475
+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
476
+ kvalues_mxfp4[vui2 & 0xF] * d);
477
+ buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
478
+ kvalues_mxfp4[vui2 >> 4] * d);
479
+ #endif
480
+ }
481
+
482
+ #if !defined(MUL_MAT_ID)
483
+ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { // copy B values starting at pos_b for tile position (row, col) into buf_b as FLOAT_TYPE_VEC2 pairs
484
+ #if LOAD_VEC_B == 8 // one data_b element = 8 values -> 4 consecutive buf_b vec2 slots
485
+ // Not supported for b_type bf16 because bf16mat2x4 does not exist
486
+ const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; // NOTE(review): vector paths do no bounds check; presumably the caller only selects them for in-range tiles — confirm
487
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
488
+ FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
489
+ buf_b[buf_idx + 0] = bb[0].xy;
490
+ buf_b[buf_idx + 1] = bb[0].zw;
491
+ buf_b[buf_idx + 2] = bb[1].xy;
492
+ buf_b[buf_idx + 3] = bb[1].zw;
493
+ #elif LOAD_VEC_B == 4 // one data_b element = 4 values -> 2 consecutive buf_b vec2 slots
494
+ const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
495
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
496
+ #if defined(DATA_B_BF16)
497
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
498
+ #else
499
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
500
+ #endif
501
+ buf_b[buf_idx + 0] = bb.xy;
502
+ buf_b[buf_idx + 1] = bb.zw;
503
+ #else // any other LOAD_VEC_B (presumably 2): two scalar loads per slot, bounds-checked against p.N and end_k
504
+ const uint idx = pos_b + col * p.stride_b + row * 2;
505
+ const uint buf_idx = col * SHMEM_STRIDE + row;
506
+ if (idx_n < p.N && block + row * 2 + 1 < end_k) { // both values in range
507
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
508
+ TO_FLOAT_TYPE(data_b[idx + 1]));
509
+ } else if (idx_n < p.N && block + row * 2 < end_k) { // only the first value is before end_k: pad second with 0
510
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
511
+ } else { // column at/after p.N or row at/after end_k: zero-fill
512
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
513
+ }
514
+ #endif
515
+ }
516
+ #else
517
+ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { // MUL_MAT_ID variant: B row for tile column col is looked up through row_ids[col]
518
+ #if LOAD_VEC_B == 8 // one data_b element = 8 values -> 4 consecutive buf_b vec2 slots
519
+ // Not supported for b_type bf16 because bf16mat2x4 does not exist
520
+ const u16vec2 row_idx = row_ids[col]; // .y selects the batch (batch_stride_b), .x modulo p.ne11 selects the row (stride_b)
521
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
522
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
523
+ FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
524
+ buf_b[buf_idx + 0] = bb[0].xy;
525
+ buf_b[buf_idx + 1] = bb[0].zw;
526
+ buf_b[buf_idx + 2] = bb[1].xy;
527
+ buf_b[buf_idx + 3] = bb[1].zw;
528
+ #elif LOAD_VEC_B == 4 // one data_b element = 4 values -> 2 consecutive buf_b vec2 slots
529
+ const u16vec2 row_idx = row_ids[col];
530
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
531
+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
532
+ #if defined(DATA_B_BF16)
533
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
534
+ #else
535
+ FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
536
+ #endif
537
+ buf_b[buf_idx + 0] = bb.xy;
538
+ buf_b[buf_idx + 1] = bb.zw;
539
+ #else // any other LOAD_VEC_B (presumably 2): two scalar loads per slot, bounds-checked against _ne1 and end_k
540
+ const uint row_i = ic * BN + col; // gathered-row index of this tile column; compared against _ne1
541
+ const uint buf_idx = col * SHMEM_STRIDE + row;
542
+ if (row_i < _ne1 && block + row * 2 + 1 < end_k) { // both values in range
543
+ const u16vec2 row_idx = row_ids[col];
544
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
545
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
546
+ TO_FLOAT_TYPE(data_b[idx + 1]));
547
+ } else if (row_i < _ne1 && block + row * 2 < end_k) { // only the first value is before end_k: pad second with 0
548
+ const u16vec2 row_idx = row_ids[col];
549
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
550
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
551
+ } else { // row_i at/after _ne1 or row at/after end_k: zero-fill without touching row_ids
552
+ buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
553
+ }
554
+ #endif
555
+ }
556
+ #endif
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
28
28
  #if defined(A_TYPE_PACKED32)
29
29
  layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
30
30
  #endif
31
- layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31
+ layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
32
32
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
33
33
 
34
34
  #ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
98
98
  #endif
99
99
 
100
100
  #define LOAD_VEC_A (4 * QUANT_R)
101
- #define LOAD_VEC_B 4
101
+ #define LOAD_VEC_B 16
102
102
 
103
103
  #ifdef MUL_MAT_ID
104
104
  shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270
270
  const uint iqs = idx & 0x7;
271
271
  #else
272
272
  const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273
+ const uint ib_outer = ib / 4;
274
+ const uint ib_inner = ib % 4;
275
+
273
276
  const uint iqs = loadr_b;
274
277
  #endif
275
278
 
276
279
  const uint buf_ib = loadc_b + l;
277
280
 
278
281
  if (iqs == 0) {
279
- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282
+ buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280
283
  }
281
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284
+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288
+ buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282
289
  }
283
290
 
284
291
  barrier();
@@ -349,7 +356,7 @@ void main() {
349
356
  cache_b_qs[cc * (BK / 4) + idx_k]);
350
357
  }
351
358
 
352
- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
359
+ sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
353
360
  }
354
361
  }
355
362
  }
@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
16
16
  (vui >> 4) & 0x0F0F0F0F);
17
17
  }
18
18
 
19
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
20
- return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
19
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { // da * (q_sum * dsb.x - offset * dsb.y), offset = 8 / sum_divisor
20
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); // NOTE(review): 8 / sum_divisor is integer division — exact only when sum_divisor divides 8
21
21
  }
22
22
  #endif
23
23
 
@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
29
29
  (vui >> 4) & 0x0F0F0F0F);
30
30
  }
31
31
 
32
- ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
33
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
32
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { // q_sum * dma.x * dsb.x plus the dma.y * dsb.y term scaled by 1 / sum_divisor
33
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); // float division here (dsb.y is float), unlike the integer offset in the da variants
34
34
  }
35
35
  #endif
36
36
 
@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
50
50
  return i32vec2(v0, v1);
51
51
  }
52
52
 
53
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
54
- return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
53
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { // da * (q_sum * dsb.x - offset * dsb.y), offset = 16 / sum_divisor
54
+ return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); // NOTE(review): 16 / sum_divisor is integer division — exact only when sum_divisor divides 16
55
55
  }
56
56
  #endif
57
57
 
@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
69
69
  return i32vec2(v0, v1);
70
70
  }
71
71
 
72
- ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
73
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
72
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { // q_sum * dma.x * dsb.x plus the dma.y * dsb.y term scaled by 1 / sum_divisor
73
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); // float division here (dsb.y is float), unlike the integer offset in the da variants
74
74
  }
75
75
  #endif
76
76
 
@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
81
81
  data_a[ib].qs[iqs * 2 + 1]));
82
82
  }
83
83
 
84
- ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
84
+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { // no offset term: sum_divisor is accepted for a uniform signature but unused
85
85
  return ACC_TYPE(float(q_sum) * da * dsb.x);
86
86
  }
87
87
  #endif
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
92
92
  }
93
93
  #endif
94
94
 
95
+ #if defined(DATA_A_MXFP4)
96
+ FLOAT_TYPE get_d(uint ib) { // per-block scale for MXFP4: decodes data_a[ib].e via e8m0_to_fp32
97
+ return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
98
+ }
99
+ #endif
100
+
95
101
  #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
96
102
  FLOAT_TYPE_VEC2 get_dm(uint ib) {
97
103
  return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);