whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -1,58 +1,65 @@
1
- #decl(SHMEM_VEC)
1
+ #ifdef VEC
2
+ #define VEC_SIZE 4
3
+ #define SHMEM_TYPE vec4<f16>
4
+ #define DST_TYPE vec4<f32>
5
+ #define SRC0_TYPE vec4<SRC0_INNER_TYPE>
6
+ #define SRC1_TYPE vec4<SRC1_INNER_TYPE>
7
+
2
8
  fn store_shmem(val: vec4<f16>, idx: u32) {
3
9
  shmem[idx] = val.x;
4
10
  shmem[idx + 1] = val.y;
5
11
  shmem[idx + 2] = val.z;
6
12
  shmem[idx + 3] = val.w;
7
13
  }
8
- #enddecl(SHMEM_VEC)
14
+ #endif // VEC
15
+
16
+ #ifdef SCALAR
17
+ #define VEC_SIZE 1
18
+ #define SHMEM_TYPE f16
19
+ #define DST_TYPE f32
20
+ #define SRC0_TYPE SRC0_INNER_TYPE
21
+ #define SRC1_TYPE SRC1_INNER_TYPE
9
22
 
10
- #decl(SHMEM_SCALAR)
11
23
  fn store_shmem(val: f16, idx: u32) {
12
24
  shmem[idx] = val;
13
25
  }
14
- #enddecl(SHMEM_SCALAR)
15
-
16
- #decl(INIT_SRC0_SHMEM_FLOAT)
26
+ #endif // SCALAR
17
27
 
28
+ #ifdef INIT_SRC0_SHMEM_FLOAT
18
29
  fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
19
- for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
30
+ for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
20
31
  let tile_m = elem_idx / TILE_K;
21
32
  let tile_k = elem_idx % TILE_K;
22
33
  let global_m = offset_m + tile_m;
23
34
  let global_k = k_outer + tile_k;
24
35
  let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
25
36
  let src0_val = select( // taking a slight performance hit to avoid oob
26
- {{SRC0_TYPE}}(0.0),
27
- src0[src0_idx/{{VEC_SIZE}}],
37
+ SRC0_TYPE(0.0),
38
+ src0[src0_idx/VEC_SIZE],
28
39
  global_m < params.m && global_k < params.k);
29
- store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
40
+ store_shmem(SHMEM_TYPE(src0_val), elem_idx);
30
41
  }
31
42
  }
43
+ #endif // INIT_SRC0_SHMEM_FLOAT
32
44
 
33
- #enddecl(INIT_SRC0_SHMEM_FLOAT)
34
-
35
- #decl(INIT_SRC1_SHMEM)
36
-
45
+ #ifdef INIT_SRC1_SHMEM_FLOAT
37
46
  fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
38
- for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
47
+ for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
39
48
  let tile_n = elem_idx / TILE_K;
40
49
  let tile_k = elem_idx % TILE_K;
41
50
  let global_n = offset_n + tile_n;
42
51
  let global_k = k_outer + tile_k;
43
52
  let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
44
53
  let src1_val = select(
45
- {{SRC1_TYPE}}(0.0),
46
- src1[src1_idx/{{VEC_SIZE}}],
54
+ SRC1_TYPE(0.0),
55
+ src1[src1_idx/VEC_SIZE],
47
56
  global_n < params.n && global_k < params.k);
48
- store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
57
+ store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
49
58
  }
50
59
  }
60
+ #endif // INIT_SRC1_SHMEM_FLOAT
51
61
 
52
- #enddecl(INIT_SRC1_SHMEM)
53
-
54
- #decl(INIT_SRC0_SHMEM_Q4_0)
55
-
62
+ #ifdef INIT_SRC0_SHMEM_Q4_0
56
63
  const BLOCK_SIZE = 32u;
57
64
  // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
58
65
  override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
93
100
  }
94
101
  }
95
102
  }
103
+ #endif // INIT_SRC0_SHMEM_Q4_0
104
+
105
+ #ifdef INIT_SRC0_SHMEM_Q4_1
106
+ const BLOCK_SIZE = 32u;
107
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
108
+ override BLOCKS_K = TILE_K/BLOCK_SIZE;
109
+ const NQ = 16u;
110
+ const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
111
+ const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
112
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
113
+
114
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
115
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
116
+ let blck_idx = i / BLOCK_SIZE;
117
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
118
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
119
+
120
+ let tile_m = blck_idx / BLOCKS_K;
121
+ let global_m = offset_m + tile_m;
122
+ let block_k = blck_idx % BLOCKS_K;
123
+ let global_k = k_outer / BLOCK_SIZE + block_k;
124
+
125
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
126
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
127
+ let scale_idx = src0_idx * F16_PER_BLOCK;
128
+ let d = src0[scale_idx];
129
+ let m = src0[scale_idx + 1u];
130
+
131
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
132
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
133
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
134
+
135
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
136
+ for (var k = 0u; k < 4u; k++) {
137
+ let q_byte = get_byte(q_packed, k);
138
+ let q_lo = f16(q_byte & 0xF) * d + m;
139
+ let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
140
+ shmem[shmem_idx + j * 2 + k] = q_lo;
141
+ shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
142
+ }
143
+ }
144
+ }
145
+ }
146
+ }
147
+ #endif // INIT_SRC0_SHMEM_Q4_1
148
+
149
+ #ifdef INIT_SRC0_SHMEM_Q5_0
150
+ // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
151
+ const BLOCK_SIZE = 32u;
152
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
153
+ // tile_k is defined as 32u, so blocks_k ends up being 1 always
154
+ override BLOCKS_K = TILE_K / BLOCK_SIZE;
155
+ const NQ = 16u;
156
+ const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
157
+ const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
158
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
159
+
160
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
161
+
162
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
163
+ let blck_idx = i / BLOCK_SIZE;
164
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
165
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
166
+
167
+ let tile_m = blck_idx / BLOCKS_K;
168
+ let global_m = offset_m + tile_m;
169
+ let block_k = blck_idx % BLOCKS_K;
170
+ let global_k = k_outer / BLOCK_SIZE + block_k;
171
+
172
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
173
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
174
+ let scale_idx = src0_idx * F16_PER_BLOCK;
96
175
 
97
- #enddecl(INIT_SRC0_SHMEM_Q4_0)
176
+ let d = src0[scale_idx];
177
+ let qh0 = src0[scale_idx + 1u];
178
+ let qh1 = src0[scale_idx + 2u];
179
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
180
+
181
+ for (var j = 0u; j < 2; j++) {
182
+ let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
183
+ let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
184
+
185
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
186
+
187
+ let j_adjusted = j + (block_offset / 2u);
188
+
189
+
190
+ for (var k = 0u; k < 4u; k++) {
191
+ let q_byte = get_byte(q_packed, k);
192
+
193
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
194
+ let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
195
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
196
+ let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
197
+
198
+ shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
199
+ shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
200
+ }
201
+ }
202
+ }
203
+ }
204
+ }
205
+ #endif // INIT_SRC0_SHMEM_Q5_0
206
+
207
+ #ifdef INIT_SRC0_SHMEM_Q5_1
208
+ // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
209
+ const BLOCK_SIZE = 32u;
210
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
211
+ // tile_k is defined as 32u, so blocks_k ends up being 1 always
212
+ override BLOCKS_K = TILE_K / BLOCK_SIZE;
213
+ const NQ = 16u;
214
+ const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
215
+ const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
216
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
217
+
218
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
219
+
220
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
221
+ let blck_idx = i / BLOCK_SIZE;
222
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
223
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
224
+
225
+ let tile_m = blck_idx / BLOCKS_K;
226
+ let global_m = offset_m + tile_m;
227
+ let block_k = blck_idx % BLOCKS_K;
228
+ let global_k = k_outer / BLOCK_SIZE + block_k;
229
+
230
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
231
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
232
+ let scale_idx = src0_idx * F16_PER_BLOCK;
233
+
234
+ let d = src0[scale_idx];
235
+ let m = src0[scale_idx + 1u];
236
+ let qh0 = src0[scale_idx + 2u];
237
+ let qh1 = src0[scale_idx + 3u];
238
+ let qh_packed = bitcast<u32>(vec2(qh0, qh1));
239
+
240
+ for (var j = 0u; j < 2; j++) {
241
+
242
+ let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
243
+ let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
244
+
245
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
246
+
247
+ let j_adjusted = j + (block_offset / 2u);
248
+
249
+
250
+ for (var k = 0u; k < 4u; k++) {
251
+ let q_byte = get_byte(q_packed, k);
252
+
253
+ let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
254
+ let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
255
+ let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
256
+ let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
257
+
258
+ shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
259
+ shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
260
+ }
261
+ }
262
+ }
263
+ }
264
+ }
265
+ #endif // INIT_SRC0_SHMEM_Q5_1
266
+
267
+ #ifdef INIT_SRC0_SHMEM_Q8_0
268
+ const BLOCK_SIZE = 32u;
269
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
270
+ override BLOCKS_K = TILE_K/BLOCK_SIZE;
271
+ const NQ = 16u;
272
+ const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
273
+ const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
274
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
275
+
276
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
277
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
278
+ let blck_idx = i / BLOCK_SIZE;
279
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
280
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
281
+
282
+ let tile_m = blck_idx / BLOCKS_K;
283
+ let global_m = offset_m + tile_m;
284
+ let block_k = blck_idx % BLOCKS_K;
285
+ let global_k = k_outer / BLOCK_SIZE + block_k;
286
+
287
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
288
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
289
+ let scale_idx = src0_idx * F16_PER_BLOCK;
290
+ let d = src0[scale_idx];
291
+
292
+ for (var j = 0u; j < F16_PER_THREAD; j+=2) {
293
+ let q_0 = src0[scale_idx + 1u + block_offset + j];
294
+ let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
295
+
296
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
297
+ for (var k = 0u; k < 4u; k++) {
298
+ let q_byte = get_byte_i32(q_packed, k);
299
+
300
+ let q_val = f16(q_byte) * d;
301
+ shmem[shmem_idx + j * 2 + k] = q_val;
302
+ }
303
+ }
304
+ }
305
+ }
306
+ }
307
+ #endif // INIT_SRC0_SHMEM_Q8_0
308
+
309
+ #ifdef INIT_SRC0_SHMEM_Q8_1
310
+ const BLOCK_SIZE = 32u;
311
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
312
+ override BLOCKS_K = TILE_K/BLOCK_SIZE;
313
+ const NQ = 16u;
314
+ const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
315
+ const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
316
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
317
+
318
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
319
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
320
+ let blck_idx = i / BLOCK_SIZE;
321
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
322
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
323
+
324
+ let tile_m = blck_idx / BLOCKS_K;
325
+ let global_m = offset_m + tile_m;
326
+ let block_k = blck_idx % BLOCKS_K;
327
+ let global_k = k_outer / BLOCK_SIZE + block_k;
328
+
329
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
330
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
331
+ let scale_idx = src0_idx * F16_PER_BLOCK;
332
+ let d = src0[scale_idx];
333
+ let m = src0[scale_idx + 1u];
334
+
335
+ for (var j = 0u; j < F16_PER_THREAD; j+=2) {
336
+ let q_0 = src0[scale_idx + 2u + block_offset + j];
337
+ let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
338
+
339
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
340
+ for (var k = 0u; k < 4u; k++) {
341
+ let q_byte = get_byte_i32(q_packed, k);
342
+
343
+ let q_val = f16(q_byte) * d + m;
344
+ shmem[shmem_idx + j * 2 + k] = q_val;
345
+ }
346
+ }
347
+ }
348
+ }
349
+ }
350
+ #endif // INIT_SRC0_SHMEM_Q8_1
351
+
352
+ #ifdef INIT_SRC0_SHMEM_Q2_K
353
+ const BLOCK_SIZE = 256u;
354
+ const F16_PER_BLOCK = 42u;
355
+
356
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
357
+ // Use standard thread layout instead of lane/row_group
358
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
359
+ let tile_m = elem_idx / TILE_K;
360
+ let tile_k = elem_idx % TILE_K;
361
+
362
+ let global_m = offset_m + tile_m;
363
+ let global_k = k_outer + tile_k;
364
+
365
+ if (global_m >= params.m || global_k >= params.k) {
366
+ shmem[elem_idx] = f16(0.0);
367
+ continue;
368
+ }
369
+
370
+ let block_k = global_k / BLOCK_SIZE;
371
+ let k_in_block = global_k % BLOCK_SIZE;
372
+
373
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
374
+ let scale_idx = src0_idx * F16_PER_BLOCK;
375
+
376
+ let d = src0[scale_idx + 40u];
377
+ let dmin = src0[scale_idx + 41u];
378
+
379
+ // Decode the element at position k_in_block
380
+ let block_of_32 = k_in_block / 32u;
381
+ let pos_in_32 = k_in_block % 32u;
382
+
383
+ let q_b_idx = (block_of_32 / 4u) * 32u;
384
+ let shift = (block_of_32 % 4u) * 2u;
385
+ let k = (pos_in_32 / 16u) * 16u;
386
+ let l = pos_in_32 % 16u;
387
+
388
+ let is = k_in_block / 16u;
389
+
390
+ let sc_0 = src0[scale_idx + 2u * (is / 4u)];
391
+ let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
392
+ let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
393
+ let sc = get_byte(sc_packed, is % 4u);
394
+
395
+ let dl = d * f16(sc & 0xFu);
396
+ let ml = dmin * f16(sc >> 4u);
397
+
398
+ let q_idx = q_b_idx + k + l;
399
+ let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
400
+ let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
401
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
402
+ let q_byte = get_byte(q_packed, q_idx % 4u);
403
+ let qs_val = (q_byte >> shift) & 3u;
404
+
405
+ let q_val = f16(qs_val) * dl - ml;
406
+ shmem[elem_idx] = q_val;
407
+ }
408
+ }
409
+ #endif // INIT_SRC0_SHMEM_Q2_K
410
+
411
+ #ifdef INIT_SRC0_SHMEM_Q3_K
412
+ const BLOCK_SIZE = 256u;
413
+ const F16_PER_BLOCK = 55u;
414
+
415
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
416
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
417
+ let tile_m = elem_idx / TILE_K;
418
+ let tile_k = elem_idx % TILE_K;
419
+
420
+ let global_m = offset_m + tile_m;
421
+ let global_k = k_outer + tile_k;
422
+
423
+ if (global_m >= params.m || global_k >= params.k) {
424
+ shmem[elem_idx] = f16(0.0);
425
+ continue;
426
+ }
427
+
428
+ let block_k = global_k / BLOCK_SIZE;
429
+ let k_in_block = global_k % BLOCK_SIZE;
430
+
431
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
432
+ let scale_idx = src0_idx * F16_PER_BLOCK;
433
+
434
+ let d = src0[scale_idx + 54u];
435
+
436
+ // Load and unpack scales
437
+ let kmask1: u32 = 0x03030303u;
438
+ let kmask2: u32 = 0x0f0f0f0fu;
439
+
440
+ var scale_vals: array<u32, 4>;
441
+ for (var i: u32 = 0u; i < 4u; i++) {
442
+ let scale_0 = src0[scale_idx + 48u + (2u*i)];
443
+ let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
444
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
445
+ }
446
+
447
+ var tmp: u32 = scale_vals[2];
448
+ scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
449
+ scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
450
+ scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
451
+ scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
452
+
453
+ // Load hmask and qs arrays
454
+ var hmask_vals: array<u32, 8>;
455
+ for (var i: u32 = 0u; i < 8u; i++) {
456
+ let hmask_0 = src0[scale_idx + (2u*i)];
457
+ let hmask_1 = src0[scale_idx + (2u*i) + 1u];
458
+ hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
459
+ }
460
+
461
+ var qs_vals: array<u32, 16>;
462
+ for (var i: u32 = 0u; i < 16u; i++) {
463
+ let qs_0 = src0[scale_idx + 16u + (2u*i)];
464
+ let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
465
+ qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
466
+ }
467
+
468
+ let half = k_in_block / 128u; // 0 or 1
469
+ let pos_in_half = k_in_block % 128u; // 0-127
470
+ let shift_group = pos_in_half / 32u; // 0-3
471
+ let pos_in_32 = pos_in_half % 32u; // 0-31
472
+ let k_group = pos_in_32 / 16u; // 0 or 1
473
+ let l = pos_in_32 % 16u; // 0-15
474
+
475
+ let q_b_idx = half * 32u; // 0 or 32
476
+ let shift = shift_group * 2u; // 0, 2, 4, 6
477
+ let k = k_group * 16u; // 0 or 16
478
+ let is = k_in_block / 16u; // 0-15
479
+
480
+ // m increments every 32 elements across entire 256 element block
481
+ let m_shift = k_in_block / 32u; // 0-7
482
+ let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
483
+
484
+ let sc = get_byte(scale_vals[is / 4u], is % 4u);
485
+ let dl = d * (f16(sc) - 32.0);
486
+
487
+ let q_idx = q_b_idx + k + l;
488
+ let hm_idx = k + l;
489
+
490
+ let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
491
+ let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
492
+
493
+ let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
494
+ let qs_val = (q_byte >> shift) & 3u;
495
+
496
+ let q_val = (f16(qs_val) - f16(hm)) * dl;
497
+ shmem[elem_idx] = q_val;
498
+ }
499
+ }
500
+
501
+ #endif // INIT_SRC0_SHMEM_Q3_K
502
+
503
+ #ifdef INIT_SRC0_SHMEM_Q4_K
504
+ const BLOCK_SIZE = 256u; // elements per quantized block along K (global_k is split by this below)
505
+ const F16_PER_BLOCK = 72u; // src0 elements (f16, going by the name) occupied by one block
506
+ // Q4_K variant: each thread dequantizes tile elements thread_id, thread_id + TOTAL_WORKGROUP_SIZE, ... of src0 into shmem. get_byte is defined outside this block; presumably get_byte(w, i) yields byte i of the u32 w (confirm).
507
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
508
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { // one tile element per thread per pass
509
+ let tile_m = elem_idx / TILE_K; // row inside the tile (tile rows are TILE_K wide)
510
+ let tile_k = elem_idx % TILE_K; // K offset inside the tile
511
+
512
+ let global_m = offset_m + tile_m;
513
+ let global_k = k_outer + tile_k;
514
+
515
+ if (global_m >= params.m || global_k >= params.k) { // positions outside the matrix are zero-filled
516
+ shmem[elem_idx] = f16(0.0);
517
+ continue;
518
+ }
519
+
520
+ let block_k = global_k / BLOCK_SIZE; // which block along K
521
+ let k_in_block = global_k % BLOCK_SIZE; // element position inside that block (0-255)
522
+
523
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k; // counted in whole blocks: it is scaled by F16_PER_BLOCK below
524
+ let scale_idx = src0_idx * F16_PER_BLOCK; // index of the block's first src0 element
525
+
526
+ let d = src0[scale_idx]; // block element 0: factor applied to the unpacked scale (dl below)
527
+ let dmin = src0[scale_idx + 1u]; // block element 1: factor applied to the unpacked min (ml below)
528
+
529
+ // Load packed scales: block elements 2..7 regrouped into three u32 words (12 bytes)
530
+ var scale_vals: array<u32, 3>;
531
+ for (var i: u32 = 0u; i < 3u; i++) {
532
+ let scale_0 = src0[scale_idx + 2u + (2u*i)];
533
+ let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
534
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
535
+ }
536
+
537
+ // Map k_in_block to loop structure:
538
+ // Outer loop over 64-element groups (alternating q_b_idx)
539
+ // Inner loop over 2 shifts per group
540
+ let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
541
+ let pos_in_64 = k_in_block % 64u; // 0-63
542
+ let shift_group = pos_in_64 / 32u; // 0 or 1
543
+ let l = pos_in_64 % 32u; // 0-31
544
+
545
+ let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
546
+ let shift = shift_group * 4u; // 0 or 4
547
+ let is = k_in_block / 32u; // 0-7
548
+
549
+ var sc: u32;
550
+ var mn: u32;
551
+
552
+ if (is < 4u) { // sub-blocks 0-3: low 6 bits of scale byte `is` (scale) and scale byte `is + 4` (min)
553
+ let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
554
+ let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); // (is + 4) % 4 == is % 4, so this addresses byte is + 4
555
+ sc = sc_byte & 63u;
556
+ mn = min_byte & 63u;
557
+ } else { // sub-blocks 4-7: 4 low bits from byte is + 4, 2 high bits from the top of bytes is - 4 / is
558
+ let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); // low nibble = scale bits 0-3, high nibble = min bits 0-3
559
+ let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); // bits 6-7 = scale bits 4-5
560
+ let min_hi = get_byte(scale_vals[is / 4u], is % 4u); // bits 6-7 = min bits 4-5
561
+
562
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
563
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
564
+ }
565
+
566
+ let dl = d * f16(sc);
567
+ let ml = dmin * f16(mn);
568
+
569
+ let q_idx = q_b_idx + l; // byte index within the block's quant bytes
570
+ let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; // quant bytes start at block element 8; one u32 = 2 src0 elements = 4 bytes
571
+ let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
572
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
573
+
574
+ let q_byte = get_byte(q_packed, q_idx % 4u);
575
+ let qs_val = (q_byte >> shift) & 0xFu; // 4-bit quant: low nibble when shift == 0, high nibble when shift == 4
576
+
577
+ let q_val = f16(qs_val) * dl - ml; // q * (d * sc) - (dmin * mn)
578
+ shmem[elem_idx] = q_val;
579
+ }
580
+ }
581
+ #endif // INIT_SRC0_SHMEM_Q4_K
582
+
583
+ #ifdef INIT_SRC0_SHMEM_Q5_K
584
+ const BLOCK_SIZE = 256u; // elements per quantized block along K (global_k is split by this below)
585
+ const F16_PER_BLOCK = 88u; // src0 elements (f16, going by the name) occupied by one block
586
+ // Q5_K variant: each thread dequantizes tile elements thread_id, thread_id + TOTAL_WORKGROUP_SIZE, ... of src0 into shmem. get_byte is defined outside this block; presumably get_byte(w, i) yields byte i of the u32 w (confirm).
587
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
588
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { // one tile element per thread per pass
589
+ let tile_m = elem_idx / TILE_K; // row inside the tile (tile rows are TILE_K wide)
590
+ let tile_k = elem_idx % TILE_K; // K offset inside the tile
591
+
592
+ let global_m = offset_m + tile_m;
593
+ let global_k = k_outer + tile_k;
594
+
595
+ if (global_m >= params.m || global_k >= params.k) { // positions outside the matrix are zero-filled
596
+ shmem[elem_idx] = f16(0.0);
597
+ continue;
598
+ }
599
+
600
+ let block_k = global_k / BLOCK_SIZE; // which block along K
601
+ let k_in_block = global_k % BLOCK_SIZE; // element position inside that block (0-255)
602
+
603
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k; // counted in whole blocks: it is scaled by F16_PER_BLOCK below
604
+ let scale_idx = src0_idx * F16_PER_BLOCK; // index of the block's first src0 element
605
+
606
+ let d = src0[scale_idx]; // block element 0: factor applied to the unpacked scale (dl below)
607
+ let dmin = src0[scale_idx + 1u]; // block element 1: factor applied to the unpacked min (ml below)
608
+
609
+ // Load packed scales: block elements 2..7 regrouped into three u32 words (12 bytes)
610
+ var scale_vals: array<u32, 3>;
611
+ for (var i: u32 = 0u; i < 3u; i++) {
612
+ let scale_0 = src0[scale_idx + 2u + (2u*i)];
613
+ let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
614
+ scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
615
+ }
616
+
617
+ // The original loop processes elements in groups of 64
618
+ // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
619
+ // But the high-bit mask u advances EVERY 32 elements (after each l loop)
620
+ let group_of_64 = k_in_block / 64u; // 0-3
621
+ let pos_in_64 = k_in_block % 64u; // 0-63
622
+ let shift_group = pos_in_64 / 32u; // 0 or 1
623
+ let l = pos_in_64 % 32u; // 0-31
624
+
625
+ let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
626
+ let shift = shift_group * 4u; // 0 or 4
627
+ let is = k_in_block / 32u; // 0-7
628
+
629
+ // u is a one-bit mask chosen by k_in_block / 32 (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
630
+ let u_shift = k_in_block / 32u; // 0-7
631
+ let u: u32 = 1u << u_shift;
632
+
633
+ var sc: u32;
634
+ var mn: u32;
635
+
636
+ if (is < 4u) { // sub-blocks 0-3: low 6 bits of scale byte `is` (scale) and scale byte `is + 4` (min)
637
+ let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
638
+ let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); // (is + 4) % 4 == is % 4, so this addresses byte is + 4
639
+ sc = sc_byte & 63u;
640
+ mn = min_byte & 63u;
641
+ } else { // sub-blocks 4-7: 4 low bits from byte is + 4, 2 high bits from the top of bytes is - 4 / is
642
+ let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); // low nibble = scale bits 0-3, high nibble = min bits 0-3
643
+ let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); // bits 6-7 = scale bits 4-5
644
+ let min_hi = get_byte(scale_vals[is / 4u], is % 4u); // bits 6-7 = min bits 4-5
645
+
646
+ sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
647
+ mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
648
+ }
649
+
650
+ let dl = d * f16(sc);
651
+ let ml = dmin * f16(mn);
652
+
653
+ let q_idx = q_b_idx + l; // byte index within the block's low-nibble quant bytes
654
+ let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; // low-nibble quant bytes start at block element 24; one u32 = 2 src0 elements = 4 bytes
655
+ let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
656
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
657
+
658
+ let q_byte = get_byte(q_packed, q_idx % 4u);
659
+
660
+ let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; // high-bit bytes start at block element 8 and are indexed by l only; the mask u picks the bit for this 32-element run
661
+ let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
662
+ let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
663
+
664
+ let qh_byte = get_byte(qh_packed, l % 4u);
665
+
666
+ let qs_val = (q_byte >> shift) & 0xFu; // low 4 bits of the quant: low nibble when shift == 0, high nibble when shift == 4
667
+ let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); // 16.0 when the selected high bit is set, otherwise 0.0
668
+
669
+ let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; // (low bits + high bit) * (d * sc) - (dmin * mn)
670
+ shmem[elem_idx] = q_val;
671
+ }
672
+ }
673
+
674
+ #endif // INIT_SRC0_SHMEM_Q5_K
675
+
676
+ #ifdef INIT_SRC0_SHMEM_Q6_K
677
+ const BLOCK_SIZE = 256u; // elements per quantized block along K (global_k is split by this below)
678
+ const F16_PER_BLOCK = 105u; // src0 elements (f16, going by the name) occupied by one block
679
+ // Q6_K variant: each thread dequantizes tile elements thread_id, thread_id + TOTAL_WORKGROUP_SIZE, ... of src0 into shmem. get_byte / get_byte_i32 are defined outside this block; presumably they yield byte i of a u32 as unsigned / signed (confirm).
680
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
681
+ for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { // one tile element per thread per pass
682
+ let tile_m = elem_idx / TILE_K; // row inside the tile (tile rows are TILE_K wide)
683
+ let tile_k = elem_idx % TILE_K; // K offset inside the tile
684
+
685
+ let global_m = offset_m + tile_m;
686
+ let global_k = k_outer + tile_k;
687
+
688
+ if (global_m >= params.m || global_k >= params.k) { // positions outside the matrix are zero-filled
689
+ shmem[elem_idx] = f16(0.0);
690
+ continue;
691
+ }
692
+
693
+ let block_k = global_k / BLOCK_SIZE; // which block along K
694
+ let k_in_block = global_k % BLOCK_SIZE; // element position inside that block (0-255)
695
+
696
+ let src0_idx = batch_offset + global_m * params.stride_01 + block_k; // counted in whole blocks: it is scaled by F16_PER_BLOCK below
697
+ let scale_idx = src0_idx * F16_PER_BLOCK; // index of the block's first src0 element
698
+
699
+ let half = k_in_block / 128u; // 0 or 1: which 128-element half of the block
700
+ let pos_in_half = k_in_block % 128u; // 0-127
701
+ let quarter = pos_in_half / 32u; // 0-3: which 32-element run inside the half
702
+ let l = pos_in_half % 32u; // 0-31
703
+
704
+ let ql_b_idx = half * 64u; // each half uses 64 ql bytes
705
+ let qh_b_idx = half * 32u; // each half uses 32 qh bytes
706
+ let sc_b_idx = half * 8u; // each half uses 8 scale bytes
707
+
708
+ // Load only the ql word needed for q1/q3 (byte ql_b_idx + l; ql bytes start at block element 0)
709
+ let ql13_flat = ql_b_idx + l;
710
+ let ql13_word = ql13_flat / 4u;
711
+ let ql13 = bitcast<u32>(vec2(
712
+ src0[scale_idx + 2u * ql13_word],
713
+ src0[scale_idx + 2u * ql13_word + 1u]
714
+ ));
715
+ let ql13_b = get_byte(ql13, ql13_flat % 4u);
716
+
717
+ // Load only the ql word needed for q2/q4 (byte ql_b_idx + l + 32)
718
+ let ql24_flat = ql_b_idx + l + 32u;
719
+ let ql24_word = ql24_flat / 4u;
720
+ let ql24 = bitcast<u32>(vec2(
721
+ src0[scale_idx + 2u * ql24_word],
722
+ src0[scale_idx + 2u * ql24_word + 1u]
723
+ ));
724
+ let ql24_b = get_byte(ql24, ql24_flat % 4u);
725
+
726
+ // Load only qh word needed (byte qh_b_idx + l; qh bytes start at block element 64)
727
+ let qh_flat = qh_b_idx + l;
728
+ let qh_word = qh_flat / 4u;
729
+ let qh = bitcast<u32>(vec2(
730
+ src0[scale_idx + 64u + 2u * qh_word],
731
+ src0[scale_idx + 64u + 2u * qh_word + 1u]
732
+ ));
733
+ let qh_b = get_byte(qh, qh_flat % 4u);
734
+
735
+ let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); // quarter 0: low nibble of ql13 + qh bits 0-1, recentred by -32
736
+ let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); // quarter 1: low nibble of ql24 + qh bits 2-3
737
+ let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); // quarter 2: high nibble of ql13 + qh bits 4-5
738
+ let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); // quarter 3: high nibble of ql24 + qh bits 6-7
739
+
740
+ // Load only the scale word needed (scale bytes start at block element 96)
741
+ let is = l / 16u; // 0 or 1: which 16-element half of the 32-element run
742
+ let sc_idx = sc_b_idx + is + quarter * 2u; // one scale byte per 16 elements
743
+ let sc_word = sc_idx / 4u;
744
+ let sc = bitcast<u32>(vec2(
745
+ src0[scale_idx + 96u + 2u * sc_word],
746
+ src0[scale_idx + 96u + 2u * sc_word + 1u]
747
+ ));
748
+ let sc_val = get_byte_i32(sc, sc_idx % 4u);
749
+
750
+ let d = src0[scale_idx + 104u]; // last element of the block (index F16_PER_BLOCK - 1): overall scale factor
751
+
752
+ var q_val: f16; // keeps the one candidate (q1..q4) that belongs to this element's quarter
753
+ if (quarter == 0u) {
754
+ q_val = q1;
755
+ } else if (quarter == 1u) {
756
+ q_val = q2;
757
+ } else if (quarter == 2u) {
758
+ q_val = q3;
759
+ } else {
760
+ q_val = q4;
761
+ }
762
+
763
+ shmem[elem_idx] = d * f16(sc_val) * q_val; // d * per-16 scale * (6-bit quant - 32)
764
+ }
765
+ }
766
+ #endif // INIT_SRC0_SHMEM_Q6_K