whispercpp 1.3.3 → 1.3.5

This diff covers the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
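For readers who want to reproduce this file-level comparison locally, the following is a minimal sketch (not part of the diff itself) that shells out to the standard `gem` CLI and a Unix `diff`. It assumes both tools are on PATH and relies on RubyGems' default `<name>-<version>` unpack layout.

```ruby
# Fetch and unpack both released gem versions into the current directory.
%w[1.3.3 1.3.5].each do |version|
  system("gem", "fetch", "whispercpp", "-v", version) or abort "fetch failed"
  system("gem", "unpack", "whispercpp-#{version}.gem") or abort "unpack failed"
end

# Recursive unified diff between the two unpacked trees.
# diff exits with status 1 when differences are found, so its return
# value is deliberately not treated as an error here.
system("diff", "-ru", "whispercpp-1.3.3", "whispercpp-1.3.5")
```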
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -24,6 +24,31 @@
 
 #define UNUSED GGML_UNUSED
 
+#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t * out_mins,
+                                             int8_t * out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t mins_0_3 = sm[1] & kmask1;
+    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+#endif
+
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
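Note (not part of the diff): the new decode_q4_Kx8_scales_mins helper above unpacks the standard q4_K scale block, in which eight 6-bit sub-block scales and eight 6-bit mins are packed into 12 bytes. A minimal scalar sketch of the same unpacking, assuming only that layout and mirroring the kmask1/2/3 bit arithmetic, might look like this (the function name is illustrative, not a symbol from the patch):

    #include <stdint.h>
    #include <string.h>

    /* Scalar sketch: unpack 8 six-bit scales and 8 six-bit mins from the
     * 12-byte q4_K scale block. Widening the mins to int16 matches the
     * vmovl_u8 step of the NEON helper. */
    static void decode_q4_Kx8_scales_mins_scalar(const uint8_t * scales_in,
                                                 int16_t out_mins[8],
                                                 int8_t  out_scales[8]) {
        const uint32_t kmask1 = 0x3f3f3f3f;
        const uint32_t kmask2 = 0x0f0f0f0f;
        const uint32_t kmask3 = 0x03030303;

        uint32_t sm[3];
        memcpy(sm, scales_in, 12);

        uint32_t scales_u32[2];
        scales_u32[0] = sm[0] & kmask1;                                    // scales 0..3
        scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); // scales 4..7
        memcpy(out_scales, scales_u32, 8);

        uint32_t mins_u32[2];
        mins_u32[0] = sm[1] & kmask1;                                           // mins 0..3
        mins_u32[1] = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); // mins 4..7

        uint8_t mins_u8[8];
        memcpy(mins_u8, mins_u32, 8);
        for (int i = 0; i < 8; i++) {
            out_mins[i] = (int16_t) mins_u8[i];
        }
    }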
@@ -86,35 +111,9 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
         }
     }
 #else
-    // scalar
-    const int blck_size_interleave = 4;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
 #endif
 }
 
@@ -205,35 +204,9 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
     }
 
 #else
-    // scalar
-    const int blck_size_interleave = 8;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
 #endif
 }
 
@@ -295,29 +268,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
+    ggml_gemv_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
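Note (not part of the diff): the hunks in this file all follow the same refactor — the inline scalar reference loops are deleted and each kernel now ends by delegating to a single portable *_generic implementation when no SIMD path fired. A minimal sketch of that dispatch shape, with a hypothetical kernel name ("example_kernel" is a placeholder, not a real ggml symbol):

    #include <stddef.h>

    /* Placeholder reference implementation (hypothetical, for illustration only). */
    static void example_kernel_generic(int n, float * s, size_t bs,
                                       const void * vx, const void * vy, int nr, int nc) {
        (void) n; (void) s; (void) bs; (void) vx; (void) vy; (void) nr; (void) nc;
    }

    /* Dispatch shape applied throughout this diff: the architecture-specific
     * fast path returns early; otherwise the one shared generic fallback runs. */
    static void example_kernel(int n, float * s, size_t bs,
                               const void * vx, const void * vy, int nr, int nc) {
    #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
        /* ... SIMD fast path would go here, ending in `return;` ... */
    #endif
        example_kernel_generic(n, s, bs, vx, vy, nr, nc);
    }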
@@ -383,29 +334,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
+    ggml_gemv_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -497,31 +426,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif // #if defined(__ARM_FEATURE_SVE)
 
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    {
-        float sumf[8];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
+    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -591,31 +496,421 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4];
-        int sumi;
+    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
 
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
 
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
-                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+            float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                int8x16_t q8_qs[64 / 16];
+                for (int i = 0; i < 64 / 16; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q4_cols[8];
+                    for (int i = 0; i < 8; i++) {
+                        q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
                     }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
                 }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                      vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                      vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+            float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
+                // but still need the qs to use the low and hi bits from q4
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q4s columns iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
+
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
+                }
+
+                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
+                // p = 0 -> 0123 p2 -> 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // 0123 or 4567
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Multiply Acc bsum + mins
+                // Each pair of subblocks share the same bsums
+                // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // cols 0-3 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                // cols 4-7 bias
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x4_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
         }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
     }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
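Reading aid (not part of the diff): the two new q4_K GEMV kernels above implement the standard q4_K × q8_K arithmetic per 256-element super-block. With per-sub-block 6-bit scale $sc_j$ and min $m_j$, fp16 super-block factors $d$ and $d_{\min}$, and q8_K row scale $d_8$, the dot product decomposes as

$$
\sum_{j} \left( d\, d_8\, sc_j \sum_{i \in j} q^{(4)}_i\, q^{(8)}_i \;-\; d_{\min}\, d_8\, m_j \sum_{i \in j} q^{(8)}_i \right),
$$

where the second inner sum per sub-block is exactly the precomputed `bsums` of the q8_K block. That is why the kernels accumulate `bsums × mins` into `bias_acc` and subtract it once per block with `vmlsq_f32`, instead of subtracting the min from every 4-bit quant.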
@@ -1096,40 +1391,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         );
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4][4];
-        int sumi;
-
-        for (int y = 0; y < nr / 4; y++) {
-            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-            for (int x = 0; x < nc / ncols_interleaved; x++) {
-                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-                }
-                for (int l = 0; l < nb; l++) {
-                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                        for (int m = 0; m < 4; m++) {
-                            for (int j = 0; j < ncols_interleaved; j++) {
-                                sumi = 0;
-                                for (int i = 0; i < blocklen; ++i) {
-                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                                }
-                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                            }
-                        }
-                    }
-                }
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++)
-                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-                }
-            }
-        }
-    }
+    ggml_gemm_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1550,38 +1812,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         );
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    float sumf[4][4];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
+    ggml_gemm_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -2019,38 +2250,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
 
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    float sumf[4][8];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
+    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -2126,38 +2326,570 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
2126
2326
  }
2127
2327
  return;
2128
2328
  #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
2129
- {
2130
- float sumf[4][4];
2131
- int sumi;
2132
-
2133
- for (int y = 0; y < nr / 4; y++) {
2134
- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
2135
- for (int x = 0; x < nc / ncols_interleaved; x++) {
2136
- const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
2137
- for (int m = 0; m < 4; m++) {
2138
- for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
2329
+ ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
2330
+ }
2331
+
2332
+ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
2333
+ constexpr int qk = QK_K;
2334
+ const int nb = n / qk;
2335
+
2336
+ constexpr int ncols_interleaved = 8;
2337
+ constexpr int blocklen = 4;
2338
+
2339
+ assert(n % qk == 0);
2340
+ assert(nr % 4 == 0);
2341
+ assert(nc % ncols_interleaved == 0);
2342
+
2343
+ UNUSED(nb);
2344
+ UNUSED(ncols_interleaved);
2345
+ UNUSED(blocklen);
2346
+
2347
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
2348
+ constexpr int q8_k_blocklen = 4;
2349
+ constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
2350
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
2351
+
2352
+ // 8 accumulators: 2 row pairs × 4 col pairs
2353
+ float32x4_t acc_f32[acc_size];
2354
+
2355
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
2356
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
2357
+
2358
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
2359
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
2360
+
2361
+ for (int i = 0; i < acc_size; i++) {
2362
+ acc_f32[i] = vdupq_n_f32(0);
2363
+ }
2364
+
2365
+ for (int b = 0; b < nb; b++) {
2366
+ // d4 0 1 2 3, 4 5 6 7
2367
+ float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
2368
+ float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
2369
+ // d8 0 1 2 3
2370
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
2371
+ // mins
2372
+ float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
2373
+ float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
2374
+
2375
+ // Precomputation of scales and mins
2376
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
2377
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
2378
+ float32x4_t sbd_min_0123[q8_k_blocklen];
2379
+ float32x4_t sbd_min_4567[q8_k_blocklen];
2380
+
2381
+ sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
2382
+ sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
2383
+ sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
2384
+ sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
2385
+
2386
+ sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
2387
+ sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
2388
+ sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
2389
+ sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
2390
+
2391
+ sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
2392
+ sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
2393
+ sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
2394
+ sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
2395
+
2396
+ sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
2397
+ sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
2398
+ sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
2399
+ sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
2400
+
2401
+ // Precomputation of bsums, each vpaddq calcs all the bsums for each row
2402
+ const int16x8_t bsums[q8_k_blocklen] = {
2403
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
2404
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
2405
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
2406
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
2407
+ };
2408
+ int16_t bsums_arr[QK_K / 64][8];
2409
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
2410
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
2411
+ }
2412
+
2413
+ // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
2414
+ int32x4_t bias_acc[acc_size];
2415
+ for (int i = 0; i < acc_size; i++) {
2416
+ bias_acc[i] = vdupq_n_s32(0);
2417
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Int accumulators for qs vecdot (4 rows x 2 column quartets)
+ int32x4_t acc_lo[acc_size];
+ int32x4_t acc_hi[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int16x8_t q4sb_scales[2];
+ int16x8_t q4sb_mins[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q4sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+ }
+
+ constexpr int reads_per_sb = 8; // 8 reads of 16 bytes each => 32 qs * 4 rows
+ for (int k = 0; k < reads_per_sb; k++) {
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+ // 0..3 & 32..35
+ const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+ const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+ const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+ const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+
+ const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+ const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
+ }
+
+ // Scale and bias application
+ // acc is stored interleaved to match the output layout
+ const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+ const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+ const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+ const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ // Scales
+ // row c0123 blk0 and blk1
+ const float32x4_t sumf_0123 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+ // row c4567 blk0 and blk1
+ const float32x4_t sumf_4567 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+ // Bias
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+ // row c0123 blk0 and blk1
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+ // row c4567 blk0 and blk1
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ }
+ } // for sb
+
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+ acc_f32[2 * row + 1] =
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+ }
+ } // for b
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
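The DOTPROD path above leans entirely on vdotq_laneq_s32, which accumulates four 4-element int8 dot products per call, selecting one 4-byte group of the second operand by lane. Below is a minimal standalone sketch of that semantic checked against a scalar loop; it is a toy example with made-up data, not part of this patch.

```cpp
// Sketch: what one vdotq_laneq_s32 call computes (build with e.g. -march=armv8.2-a+dotprod).
#include <arm_neon.h>
#include <cstdio>

int main() {
    int8_t a_bytes[16], b_bytes[16];
    for (int i = 0; i < 16; i++) { a_bytes[i] = (int8_t)(i - 8); b_bytes[i] = (int8_t)(2 * i); }

    const int8x16_t a = vld1q_s8(a_bytes);   // 4 groups of 4 int8 values
    const int8x16_t b = vld1q_s8(b_bytes);
    const int lane = 2;                      // which 4-byte group of b is broadcast (immediate below)

    // acc[i] += dot(a[4i..4i+3], b[4*lane..4*lane+3]) for i = 0..3
    int32x4_t acc = vdupq_n_s32(0);
    acc = vdotq_laneq_s32(acc, a, b, 2);

    int32_t got[4];
    vst1q_s32(got, acc);

    for (int i = 0; i < 4; i++) {            // scalar reference
        int32_t ref = 0;
        for (int k = 0; k < 4; k++) ref += a_bytes[4 * i + k] * b_bytes[4 * lane + k];
        std::printf("group %d: neon=%d scalar=%d\n", i, got[i], ref);
    }
    return 0;
}
```

This is exactly the pattern used above: q8_blk0/q8_blk1 supply the lane-selected activation row, while the nibble-unpacked q4 vector supplies the four column groups.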
+
+ void ggml_gemm_q4_K_8x8_q8_K(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ constexpr int q8_k_blocklen = 4;
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ // 8 accumulators: 2 row pairs × 4 col pairs
+ float32x4_t acc_f32[blocklen];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < blocklen; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ // bsums pairs belong to the same q8_k subblock
+ const int16x8_t bsums[4]{
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+ int16_t bsums_arr[4][8];
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
  }
- for (int l = 0; l < nb; l++) {
- for (int k = 0; k < (qk / (2 * blocklen)); k++) {
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++) {
- sumi = 0;
- for (int i = 0; i < blocklen; ++i) {
- const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
- const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
- sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
- (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+
+ int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
+ int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+ int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vdupq_n_s32(0);
+ bias_acc[i] = vdupq_n_s32(0);
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int8_t q4sb_scales[2][8];
+ int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
+ for (int i = 0; i < 2; i++) {
+ const int offset = sb * 24 + i * 12;
+ decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+ }
+
+ // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+ const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+ int8x16_t q8_qs_01[8];
+ int8x16_t q8_qs_23[8];
+
+ // Load 32 bytes per row pair, 1 subblock each time
+ for (int i = 0; i < 8; i++) {
+ const int offset = i * 32; // 16 for rows 01, 16 for rows 23
+ q8_qs_01[i] = vld1q_s8(q8_base + offset);
+ q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+ }
+
+ const int8x16_t q8s[2][8] = {
+ { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+ q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+ { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+ q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+ };
+
+ // Q4s columns iterated in pairs (01, 23, 45, 67)
+ for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+ for (int i = 0; i < 4; i++) {
+ sb_acc[i] = vdupq_n_s32(0);
+ }
+
+ uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
+ uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
+ uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
+ uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
+ const int8x16_t q4_nibbles[2][4] = {
+ {
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+ vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+ },
+ {
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+ vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+ }
+ };
+
+ // Computes the Qs multiply-add for every row pair (rp), i.e. q8 rows 01 and 23,
+ // for each of the two internal 32-qs subblocks (blk)
+ for (int rp = 0; rp < 2; rp++) {
+ for (int blk = 0; blk < 2; blk++) {
+ const int8x16_t * q8 = &q8s[rp][4 * blk];
+ const int8x16_t * q4 = q4_nibbles[blk];
+ int32x4_t acc = sb_acc[2 * rp + blk];
+ // mul add for each qs in the same subblock
+ for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+ acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
  }
- sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+ sb_acc[2 * rp + blk] = acc;
  }
  }
+
+ // Scales[i] corresponds to column i
+ const int scale_offset = cp * 2;
+ for (int blk = 0; blk < 2; blk++) {
+ const int32x4_t block_scale = {
+ (int32_t) q4sb_scales[blk][scale_offset],
+ (int32_t) q4sb_scales[blk][scale_offset],
+ (int32_t) q4sb_scales[blk][scale_offset + 1],
+ (int32_t) q4sb_scales[blk][scale_offset + 1],
+ };
+ acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
+ acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
+ }
+ }
+
+ // Accumulate the bias: bsums multiplied by mins
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ // Each pair of subblocks shares the same bsums
+ // Load the scalar bsum and broadcast it to a vector (vdup_n_s16)
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+ int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+ bias_acc[2 * q8_row] =
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ bias_acc[2 * q8_row] =
+ vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+ bias_acc[2 * q8_row + 1] =
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ bias_acc[2 * q8_row + 1] =
+ vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ }
+ } // for sb
+
+ // Reorder the i8mm output to match the bias and output layout
+ for (int i = 0; i < 8; i++) {
+ int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+ acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+ }
+ int32x4_t reorder_acc[8] = {
+ vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+ vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+ vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+ vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+ vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+ vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+ vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+ vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+ };
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ for (int j = 0; j < 2; j++) {
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+ float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+ const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
+
+ float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+ const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+ acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+ acc_f32[2 * i + j] =
+ vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+ }
+ }
+ } // for b
+
+ // With the previous reorder, the tile is already in the correct memory layout.
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ }
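For reference, vmmlaq_s32 (i8mm) treats each int8x16_t operand as a 2×8 int8 matrix and accumulates A·Bᵀ into a 2×2 int32 tile laid out as [r0c0, r0c1, r1c0, r1c1]; that tile layout is why the kernel above needs the vzip/vcombine reorder before applying scales. Below is a standalone sketch with a scalar check; it uses toy data and is not part of this patch.

```cpp
// Sketch: the 2x2 tile produced by one vmmlaq_s32 call (build with e.g. -march=armv8.2-a+i8mm).
#include <arm_neon.h>
#include <cstdio>

int main() {
    int8_t a_rows[2][8], b_rows[2][8];               // A and B, two rows of 8 int8 each
    for (int r = 0; r < 2; r++)
        for (int k = 0; k < 8; k++) { a_rows[r][k] = (int8_t)(r + k); b_rows[r][k] = (int8_t)(k - 3); }

    const int8x16_t a = vld1q_s8(&a_rows[0][0]);     // the two rows are stored back to back
    const int8x16_t b = vld1q_s8(&b_rows[0][0]);

    // acc += A (2x8) * B^T (8x2) -> 2x2 tile stored row-major: [r0c0, r0c1, r1c0, r1c1]
    int32x4_t acc = vmmlaq_s32(vdupq_n_s32(0), a, b);

    int32_t tile[4];
    vst1q_s32(tile, acc);

    for (int r = 0; r < 2; r++)
        for (int c = 0; c < 2; c++) {
            int32_t ref = 0;                         // scalar reference: dot(A row r, B row c)
            for (int k = 0; k < 8; k++) ref += a_rows[r][k] * b_rows[c][k];
            std::printf("r%dc%d: neon=%d scalar=%d\n", r, c, tile[2 * r + c], ref);
        }
    return 0;
}
```

In the kernel above, the "rows" of each operand are a pair of interleaved Q8 activation rows and a pair of Q4 columns, so each call fills one row-pair/column-pair tile of the output.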
+
+
+ void ggml_gemm_q8_0_4x4_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+ float32x4_t sumf[4];
+ for (int m = 0; m < 4; m++) {
+ sumf[m] = vdupq_n_f32(0);
+ }
+
+ for (int l = 0; l < nb; l++) {
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
+
+ int32x4_t sumi_0 = vdupq_n_s32(0);
+ int32x4_t sumi_1 = vdupq_n_s32(0);
+ int32x4_t sumi_2 = vdupq_n_s32(0);
+ int32x4_t sumi_3 = vdupq_n_s32(0);
+
+ for (int k_group = 0; k_group < 8; k_group += 4) {
+ int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
+ int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
+
+ for (int k = 0; k < 4; k++) {
+ sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
  }
  }
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++)
- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+ }
+
+ for (int m = 0; m < 4; m++) {
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+ }
+ }
+ }
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ }
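The generic loop removed in this hunk is the clearest statement of what these 4×4 tile kernels compute: for each QK8_0 block, an int8 dot product per (row, column) pair, scaled by the product of the two per-block deltas. The sketch below restates that contract over plain deinterleaved arrays; it is a reference model under that assumption, not the shipped implementation, which reads the same values out of block_q8_0x4's interleaved qs/d fields.

```cpp
// Scalar reference for one 4x4 output tile over a single QK8_0 block:
//   out[m][j] += d_a[m] * d_b[j] * sum_i a[m][i] * b[j][i]
#include <cstdint>

constexpr int kBlockSize = 32;  // assumed to match QK8_0

void gemm_q8_0_4x4_block_ref(float out[4][4],
                             const int8_t a[4][kBlockSize], const float d_a[4],
                             const int8_t b[4][kBlockSize], const float d_b[4]) {
    for (int m = 0; m < 4; m++) {          // activation rows
        for (int j = 0; j < 4; j++) {      // weight columns
            int32_t sumi = 0;
            for (int i = 0; i < kBlockSize; i++) {
                sumi += (int32_t) a[m][i] * (int32_t) b[j][i];
            }
            out[m][j] += (float) sumi * d_a[m] * d_b[j];
        }
    }
}
```

The NEON version above reaches the same result by keeping the per-row sums in int32 vectors and folding in the d_a[m] * d_b[j] factor once per block via vmulq_laneq_f32.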
+
+ void ggml_gemm_q8_0_4x8_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+ for (int y = 0; y < nr; y += 4) {
+ const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+ for (int x = 0; x < nc; x += ncols_interleaved) {
+ const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+ const block_q8_0x4 * a_ptr = a_ptr_base;
+
+ float32x4_t acc_f32[4];
+ for (int i = 0; i < 4; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vdupq_n_s32(0);
  }
+
+ // Process 4 chunks of 8 positions each
+ for (int chunk = 0; chunk < 4; chunk++) {
+ int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+ int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+ int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+ int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
+ acc[0] = vmmlaq_s32(acc[0], a01, b01);
+ acc[1] = vmmlaq_s32(acc[1], a01, b23);
+ acc[2] = vmmlaq_s32(acc[2], a23, b01);
+ acc[3] = vmmlaq_s32(acc[3], a23, b23);
+ }
+
+ // Reorder outputs from 2×2 tiles to row-major
+ // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+ // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+ // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+ // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+ int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+ int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+ int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+ int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+ // Scales
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+ acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+ acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+ acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+ acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+ a_ptr++;
+ b_ptr++;
+ }
+
+ for (int row = 0; row < 4; row++) {
+ vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
  }
  }
  }
+ return;
+ #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
  }
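The tile-to-row reorder used by the i8mm kernels above is worth spelling out once in isolation: four 2×2 tiles covering a 4×4 output are stitched into four row vectors with vget_low_s32/vget_high_s32 and vcombine_s32, exactly as annotated in ggml_gemm_q8_0_4x8_q8_0. The sketch below exercises just that shuffle with toy tile values whose digits encode their own (row, col); it is a standalone illustration, not part of this patch.

```cpp
// Sketch: stitching four 2x2 int32 tiles (the layout produced by vmmlaq_s32) into 4 rows.
// tiles[0] = [r0c0, r0c1, r1c0, r1c1]   tiles[1] = [r0c2, r0c3, r1c2, r1c3]
// tiles[2] = [r2c0, r2c1, r3c0, r3c1]   tiles[3] = [r2c2, r2c3, r3c2, r3c3]
#include <arm_neon.h>
#include <cstdio>

int main() {
    int32_t tiles[4][4];
    for (int t = 0; t < 4; t++)
        for (int e = 0; e < 4; e++) {
            int r = (t / 2) * 2 + e / 2;             // which output row this tile element covers
            int c = (t % 2) * 2 + e % 2;             // which output column it covers
            tiles[t][e] = 10 * r + c;                // value encodes its own (row, col)
        }

    int32x4_t acc[4];
    for (int t = 0; t < 4; t++) acc[t] = vld1q_s32(tiles[t]);

    // Same shuffle as the kernel: low tile halves give even rows, high halves give odd rows.
    int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]),  vget_low_s32(acc[1]));
    int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
    int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]),  vget_low_s32(acc[3]));
    int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));

    int32_t out[4][4];
    vst1q_s32(out[0], row0); vst1q_s32(out[1], row1);
    vst1q_s32(out[2], row2); vst1q_s32(out[3], row3);

    for (int r = 0; r < 4; r++)                       // expect out[r][c] == 10*r + c
        std::printf("%d %d %d %d\n", out[r][0], out[r][1], out[r][2], out[r][3]);
    return 0;
}
```

Storing the rows contiguously after this shuffle is what lets the kernels finish with plain vst1q_f32 writes into the row-major output.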