whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630):
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
124
124
  }
125
125
  }
126
126
 
127
+
128
+ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
129
+ assert(QK_K == 256);
130
+ assert(k % QK_K == 0);
131
+ const int nb = k / QK_K;
132
+
133
+ block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
134
+
135
+ // scalar
136
+ const int blck_size_interleave = 4;
137
+ float srcv[4][QK_K];
138
+ float iscale[4];
139
+
140
+ for (int i = 0; i < nb; i++) {
141
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
142
+ float amax = 0.0f; // absolute max
143
+ float max = 0;
144
+
145
+ for (int j = 0; j < QK_K; j++) {
146
+ srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
147
+ // Update the maximum value of the corresponding super block
148
+ if(amax < fabsf(srcv[row_iter][j])) {
149
+ amax = fabsf(srcv[row_iter][j]);
150
+ max = srcv[row_iter][j];
151
+ }
152
+ }
153
+
154
+ iscale[row_iter] = amax ? -127.f/max : 0;
155
+
156
+ y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
157
+ }
158
+
159
+ for (int j = 0; j < QK_K / 4; j++) {
160
+ y[i].bsums[j] = 0;
161
+ }
162
+
163
+ // Quants values are interleaved in sequence of four bytes from corresponding super blocks
164
+ // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
165
+ // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
166
+ for (int j = 0; j < QK_K * 4; j++) {
167
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
168
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
169
+ src_offset += (j % blck_size_interleave);
170
+ int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
171
+
172
+ float x0 = srcv[src_id][src_offset] * iscale[src_id];
173
+ y[i].qs[j] = nearest_int(x0);
174
+ y[i].bsums[index] += y[i].qs[j];
175
+ }
176
+ }
177
+ }
178
+
127
179
  void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
128
180
  assert(QK_K == 256);
129
181
  assert(k % QK_K == 0);
@@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
192
244
  ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
193
245
  }
194
246
 
247
+ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
248
+ assert(nrow == 4);
249
+ UNUSED(nrow);
250
+ ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
251
+ }
252
+
195
253
  template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
196
254
  assert(nrow == 4);
197
255
  UNUSED(nrow);
@@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
333
391
  }
334
392
  }
335
393
 
394
+ void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
395
+ const int qk = QK_K;
396
+ const int nb = n / qk;
397
+ const int ncols_interleaved = 8;
398
+ const int blocklen = 4;
399
+ static const uint32_t kmask1 = 0x3f3f3f3f;
400
+ static const uint32_t kmask2 = 0x0f0f0f0f;
401
+ static const uint32_t kmask3 = 0x03030303;
402
+
403
+ assert (n % qk == 0);
404
+ assert (nc % ncols_interleaved == 0);
405
+
406
+ UNUSED(bs);
407
+ UNUSED(nr);
408
+
409
+ float sumf[8];
410
+ float sum_minf[8];
411
+ uint32_t utmp[32];
412
+ int sumi1;
413
+ int sumi2;
414
+ int sumi;
415
+
416
+ const block_q8_K * a_ptr = (const block_q8_K *) vy;
417
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
418
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
419
+
420
+ for (int j = 0; j < ncols_interleaved; j++) {
421
+ sumf[j] = 0.0;
422
+ sum_minf[j] = 0.0;
423
+ }
424
+ for (int l = 0; l < nb; l++) {
425
+ for (int sb = 0; sb < 8; sb++) {
426
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
427
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
428
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
429
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
430
+ utmp[sb * 4 + 2] = uaux_0;
431
+ utmp[sb * 4 + 0] &= kmask1;
432
+ }
433
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
434
+ uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
435
+ uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
436
+ for (int j = 0; j < ncols_interleaved; j++) {
437
+ sumi1 = 0;
438
+ sumi2 = 0;
439
+ sumi = 0;
440
+ for (int i = 0; i < blocklen; ++i) {
441
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
442
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
443
+ sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
444
+ sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
445
+ sumi1 = sumi1 * scales_0[j];
446
+ sumi2 = sumi2 * scales_1[j];
447
+ sumi += sumi1 + sumi2;
448
+ }
449
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
450
+ }
451
+ }
452
+ for (int sb = 0; sb < 8; sb++) {
453
+ uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
454
+ for (int j = 0; j < ncols_interleaved; j++) {
455
+ sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
456
+ }
457
+ }
458
+ }
459
+ for (int j = 0; j < ncols_interleaved; j++) {
460
+ s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
461
+ }
462
+ }
463
+ }
464
+
336
465
  void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
337
466
  const int qk = QK_K;
338
467
  const int nb = n / qk;
@@ -563,6 +692,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
563
692
  }
564
693
  }
565
694
 
695
+ void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
696
+ float * GGML_RESTRICT s,
697
+ size_t bs,
698
+ const void * GGML_RESTRICT vx,
699
+ const void * GGML_RESTRICT vy,
700
+ int nr,
701
+ int nc) {
702
+ const int qk = QK8_0;
703
+ const int nb = n / qk;
704
+ const int ncols_interleaved = 4;
705
+ const int blocklen = 4;
706
+
707
+ assert(nr == 1);
708
+ assert(n % qk == 0);
709
+ assert(nc % ncols_interleaved == 0);
710
+
711
+ UNUSED(bs);
712
+ UNUSED(nr);
713
+
714
+ float sumf[4];
715
+ int sumi;
716
+
717
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
718
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
719
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
720
+
721
+ for (int j = 0; j < ncols_interleaved; j++) {
722
+ sumf[j] = 0.0;
723
+ }
724
+ for (int l = 0; l < nb; l++) {
725
+ for (int k = 0; k < (qk / blocklen); k++) {
726
+ for (int j = 0; j < ncols_interleaved; j++) {
727
+ sumi = 0;
728
+ for (int i = 0; i < blocklen; ++i) {
729
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
730
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
731
+ }
732
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
733
+ }
734
+ }
735
+ }
736
+ for (int j = 0; j < ncols_interleaved; j++) {
737
+ s[x * ncols_interleaved + j] = sumf[j];
738
+ }
739
+ }
740
+ }
741
+
742
+ void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
743
+ float * GGML_RESTRICT s,
744
+ size_t bs,
745
+ const void * GGML_RESTRICT vx,
746
+ const void * GGML_RESTRICT vy,
747
+ int nr,
748
+ int nc) {
749
+ const int qk = QK8_0;
750
+ const int nb = n / qk;
751
+ const int ncols_interleaved = 4;
752
+ const int blocklen = 8;
753
+
754
+ assert(nr == 1);
755
+ assert(n % qk == 0);
756
+ assert(nc % ncols_interleaved == 0);
757
+
758
+ UNUSED(bs);
759
+ UNUSED(nr);
760
+
761
+ float sumf[4];
762
+ int sumi;
763
+
764
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
765
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
766
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
767
+
768
+ for (int j = 0; j < ncols_interleaved; j++) {
769
+ sumf[j] = 0.0;
770
+ }
771
+ for (int l = 0; l < nb; l++) {
772
+ for (int k = 0; k < (qk / blocklen); k++) {
773
+ for (int j = 0; j < ncols_interleaved; j++) {
774
+ sumi = 0;
775
+ for (int i = 0; i < blocklen; ++i) {
776
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
777
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
778
+ }
779
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
780
+ }
781
+ }
782
+ }
783
+ for (int j = 0; j < ncols_interleaved; j++) {
784
+ s[x * ncols_interleaved + j] = sumf[j];
785
+ }
786
+ }
787
+ }
788
+
566
789
  void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
567
790
  const int qk = QK8_0;
568
791
  const int nb = n / qk;
@@ -727,6 +950,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
727
950
  }
728
951
  }
729
952
 
953
+ void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
954
+ const int qk = QK_K;
955
+ const int nb = n / qk;
956
+ const int ncols_interleaved = 8;
957
+ const int blocklen = 4;
958
+ static const uint32_t kmask1 = 0x3f3f3f3f;
959
+ static const uint32_t kmask2 = 0x0f0f0f0f;
960
+ static const uint32_t kmask3 = 0x03030303;
961
+
962
+ assert (n % qk == 0);
963
+ assert (nr % 4 == 0);
964
+ assert (nc % ncols_interleaved == 0);
965
+
966
+ UNUSED(nb);
967
+ UNUSED(ncols_interleaved);
968
+ UNUSED(blocklen);
969
+
970
+ float sumf[4][8];
971
+ float sum_minf[4][8];
972
+ uint32_t utmp[32];
973
+ int sumi1;
974
+ int sumi2;
975
+ int sumi;
976
+
977
+ for (int y = 0; y < nr / 4; y++) {
978
+ const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
979
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
980
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
981
+ for (int m = 0; m < 4; m++) {
982
+ for (int j = 0; j < ncols_interleaved; j++) {
983
+ sumf[m][j] = 0.0;
984
+ sum_minf[m][j] = 0.0;
985
+ }
986
+ }
987
+ for (int l = 0; l < nb; l++) {
988
+ for (int sb = 0; sb < 8; sb++) {
989
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
990
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
991
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
992
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
993
+ utmp[sb * 4 + 2] = uaux_0;
994
+ utmp[sb * 4 + 0] &= kmask1;
995
+ }
996
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
997
+ uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
998
+ uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
999
+ for (int m = 0; m < 4; m++) {
1000
+ for (int j = 0; j < ncols_interleaved; j++) {
1001
+ sumi1 = 0;
1002
+ sumi2 = 0;
1003
+ sumi = 0;
1004
+ for (int i = 0; i < blocklen; ++i) {
1005
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
1006
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
1007
+ sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
1008
+ sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
1009
+ sumi1 = sumi1 * scales_0[j];
1010
+ sumi2 = sumi2 * scales_1[j];
1011
+ sumi += sumi1 + sumi2;
1012
+ }
1013
+ sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
1014
+ }
1015
+ }
1016
+ }
1017
+ for (int sb = 0; sb < 8; sb++) {
1018
+ uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
1019
+ for(int m = 0; m < 4; m++) {
1020
+ const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
1021
+ for(int j = 0; j < ncols_interleaved; j++) {
1022
+ sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
1023
+ }
1024
+ }
1025
+ }
1026
+ }
1027
+ for (int m = 0; m < 4; m++) {
1028
+ for (int j = 0; j < ncols_interleaved; j++) {
1029
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
1030
+ }
1031
+ }
1032
+ }
1033
+ }
1034
+ }
1035
+
730
1036
  void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
731
1037
  const int qk = QK_K;
732
1038
  const int nb = n / qk;
@@ -1007,8 +1313,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
1007
1313
  }
1008
1314
  }
1009
1315
 
1316
+ void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
1317
+ float * GGML_RESTRICT s,
1318
+ size_t bs,
1319
+ const void * GGML_RESTRICT vx,
1320
+ const void * GGML_RESTRICT vy,
1321
+ int nr,
1322
+ int nc) {
1323
+ const int qk = QK8_0;
1324
+ const int nb = n / qk;
1325
+ const int ncols_interleaved = 4;
1326
+ const int blocklen = 4;
1327
+
1328
+ assert(n % qk == 0);
1329
+ assert(nr % 4 == 0);
1330
+ assert(nc % ncols_interleaved == 0);
1331
+
1332
+ float sumf[4][4];
1333
+ int sumi;
1334
+
1335
+ for (int y = 0; y < nr / 4; y++) {
1336
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1337
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1338
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1339
+ for (int m = 0; m < 4; m++) {
1340
+ for (int j = 0; j < ncols_interleaved; j++) {
1341
+ sumf[m][j] = 0.0;
1342
+ }
1343
+ }
1344
+ for (int l = 0; l < nb; l++) {
1345
+ for (int k = 0; k < (qk / blocklen); k++) {
1346
+ for (int m = 0; m < 4; m++) {
1347
+ for (int j = 0; j < ncols_interleaved; j++) {
1348
+ sumi = 0;
1349
+ for (int i = 0; i < blocklen; ++i) {
1350
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
1351
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
1352
+ }
1353
+ sumf[m][j] +=
1354
+ sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1355
+ }
1356
+ }
1357
+ }
1358
+ }
1359
+ for (int m = 0; m < 4; m++) {
1360
+ for (int j = 0; j < ncols_interleaved; j++) {
1361
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1362
+ }
1363
+ }
1364
+ }
1365
+ }
1366
+ }
1367
+
1368
+ void ggml_gemm_q8_0_4x8_q8_0_generic(int n,
1369
+ float * GGML_RESTRICT s,
1370
+ size_t bs,
1371
+ const void * GGML_RESTRICT vx,
1372
+ const void * GGML_RESTRICT vy,
1373
+ int nr,
1374
+ int nc) {
1375
+ const int qk = QK8_0;
1376
+ const int nb = n / qk;
1377
+ const int ncols_interleaved = 4;
1378
+ const int blocklen = 8;
1379
+
1380
+ assert(n % qk == 0);
1381
+ assert(nr % 4 == 0);
1382
+ assert(nc % ncols_interleaved == 0);
1383
+
1384
+ float sumf[4][4];
1385
+ int sumi;
1386
+
1387
+ for (int y = 0; y < nr / 4; y++) {
1388
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1389
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1390
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1391
+ for (int m = 0; m < 4; m++) {
1392
+ for (int j = 0; j < ncols_interleaved; j++) {
1393
+ sumf[m][j] = 0.0;
1394
+ }
1395
+ }
1396
+ for (int l = 0; l < nb; l++) {
1397
+ for (int k = 0; k < (qk / blocklen); k++) {
1398
+ for (int m = 0; m < 4; m++) {
1399
+ for (int j = 0; j < ncols_interleaved; j++) {
1400
+ sumi = 0;
1401
+ for (int i = 0; i < blocklen; ++i) {
1402
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
1403
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
1404
+ }
1405
+ sumf[m][j] +=
1406
+ sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1407
+ }
1408
+ }
1409
+ }
1410
+ }
1411
+ for (int m = 0; m < 4; m++) {
1412
+ for (int j = 0; j < ncols_interleaved; j++) {
1413
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1414
+ }
1415
+ }
1416
+ }
1417
+ }
1418
+ }
1419
+
1010
1420
  } // extern "C"
1011
1421
 
1422
+ static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
1423
+ block_q8_0x4 out;
1424
+
1425
+ for (int i = 0; i < 4; i++) {
1426
+ out.d[i] = in[i].d;
1427
+ }
1428
+
1429
+ const int end = QK8_0 * 4 / blck_size_interleave;
1430
+ for (int i = 0; i < end; ++i) {
1431
+ int src_id = i % 4;
1432
+ int src_offset = (i / 4) * blck_size_interleave;
1433
+ int dst_offset = i * blck_size_interleave;
1434
+ memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
1435
+ }
1436
+ return out;
1437
+ }
1438
+
1012
1439
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
1013
1440
  block_q4_0x4 out;
1014
1441
 
@@ -1228,9 +1655,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
1228
1655
 
1229
1656
  GGML_UNUSED(data_size);
1230
1657
  }
1658
+
1231
1659
  static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1232
1660
  GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
1233
- GGML_ASSERT(interleave_block == 8);
1661
+ GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
1234
1662
  constexpr int nrows_interleaved = 8;
1235
1663
 
1236
1664
  block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@@ -1321,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
1321
1749
  GGML_UNUSED(data_size);
1322
1750
  }
1323
1751
 
1752
+ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t,
1753
+ int interleave_block,
1754
+ const void * GGML_RESTRICT data,
1755
+ size_t data_size) {
1756
+ GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
1757
+ GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
1758
+ constexpr int nrows_interleaved = 4;
1759
+
1760
+ block_q8_0x4 * dst = (block_q8_0x4 *) t->data;
1761
+ const block_q8_0 * src = (const block_q8_0 *) data;
1762
+ block_q8_0 dst_tmp[4];
1763
+ int nrow = ggml_nrows(t);
1764
+ int nblocks = t->ne[0] / QK8_0;
1765
+
1766
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
1767
+
1768
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1769
+ return -1;
1770
+ }
1771
+
1772
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
1773
+ for (int64_t x = 0; x < nblocks; x++) {
1774
+ for (int i = 0; i < nrows_interleaved; i++) {
1775
+ dst_tmp[i] = src[x + i * nblocks];
1776
+ }
1777
+ *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
1778
+ }
1779
+ src += nrows_interleaved * nblocks;
1780
+ }
1781
+ return 0;
1782
+ }
1783
+
1324
1784
  static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
1325
1785
  block_iq4_nlx4 out;
1326
1786
 
@@ -1468,6 +1928,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
1468
1928
  return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
1469
1929
  }
1470
1930
 
1931
+ template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
1932
+ return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
1933
+ }
1934
+
1471
1935
  template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
1472
1936
  return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
1473
1937
  }
@@ -1485,6 +1949,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
1485
1949
  return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
1486
1950
  }
1487
1951
 
1952
+ template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1953
+ return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
1954
+ }
1955
+
1956
+ template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1957
+ return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
1958
+ }
1959
+
1488
1960
  // gemv
1489
1961
  template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1490
1962
  void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1501,6 +1973,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
1501
1973
  ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1502
1974
  }
1503
1975
 
1976
+ template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1977
+ ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
1978
+ }
1979
+
1504
1980
  template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1505
1981
  ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1506
1982
  }
@@ -1517,6 +1993,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
1517
1993
  ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1518
1994
  }
1519
1995
 
1996
+ template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1997
+ ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1998
+ }
1999
+
2000
+ template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
2001
+ ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
2002
+ }
2003
+
1520
2004
  // gemm
1521
2005
  template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1522
2006
  void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1529,6 +2013,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
1529
2013
  ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1530
2014
  }
1531
2015
 
2016
+ template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
2017
+ ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
2018
+ }
2019
+
1532
2020
  template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1533
2021
  ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1534
2022
  }
@@ -1549,6 +2037,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
1549
2037
  ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1550
2038
  }
1551
2039
 
2040
+ template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
2041
+ ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
2042
+ }
2043
+
2044
+ template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
2045
+ ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
2046
+ }
2047
+
1552
2048
  class tensor_traits_base : public ggml::cpu::tensor_traits {
1553
2049
  public:
1554
2050
  virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -1600,6 +2096,55 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1600
2096
  return false;
1601
2097
  }
1602
2098
 
2099
+ void forward_mul_mat_one_chunk(ggml_compute_params * params,
2100
+ ggml_tensor * op,
2101
+ int64_t src0_start,
2102
+ int64_t src0_end,
2103
+ int64_t src1_start,
2104
+ int64_t src1_end) {
2105
+ const ggml_tensor * src0 = op->src[0];
2106
+ const ggml_tensor * src1 = op->src[1];
2107
+ ggml_tensor * dst = op;
2108
+
2109
+ GGML_TENSOR_BINARY_OP_LOCALS
2110
+
2111
+ const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
2112
+
2113
+ GGML_ASSERT(ne03 == 1 && ne13 == 1);
2114
+ GGML_ASSERT(ne12 % ne02 == 0);
2115
+ const int64_t r2 = ne12 / ne02;
2116
+
2117
+ const int64_t i12 = src1_start / ne1;
2118
+ const int64_t i11 = src1_start - i12 * ne1;
2119
+
2120
+ // Determine batch index
2121
+ const int64_t i02 = i12 / r2;
2122
+
2123
+ const int64_t i1 = i11;
2124
+ const int64_t i2 = i12;
2125
+
2126
+ const char * src0_ptr = (const char *) src0->data + i02 * nb02;
2127
+ const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
2128
+ char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
2129
+
2130
+ const int64_t nrows = src1_end - src1_start;
2131
+ const int64_t ncols = src0_end - src0_start;
2132
+
2133
+ GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
2134
+
2135
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
2136
+ if (nrows > 3) {
2137
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
2138
+ src0_ptr + src0_start * nb01, src1_ptr,
2139
+ nrows - (nrows % 4), ncols);
2140
+ }
2141
+ for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
2142
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
2143
+ ne01, src0_ptr + src0_start * nb01,
2144
+ src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
2145
+ }
2146
+ }
2147
+
1603
2148
  void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
1604
2149
  const ggml_tensor * src0 = op->src[0];
1605
2150
  const ggml_tensor * src1 = op->src[1];
@@ -1621,6 +2166,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1621
2166
  GGML_ASSERT(nb1 <= nb2);
1622
2167
  GGML_ASSERT(nb2 <= nb3);
1623
2168
 
2169
+ // TODO: General batched mul mat for 4D tensors
2170
+ // Currently only supports 3D tensors
2171
+ GGML_ASSERT(ne03 == 1);
2172
+ GGML_ASSERT(ne13 == 1);
2173
+ GGML_ASSERT(ne3 == 1);
2174
+
1624
2175
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
1625
2176
 
1626
2177
  GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
@@ -1628,46 +2179,102 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1628
2179
 
1629
2180
  char * wdata = static_cast<char *>(params->wdata);
1630
2181
  const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
2182
+ const size_t nbw2 = nbw1 * ne11;
1631
2183
 
1632
- assert(params->wsize >= nbw1 * ne11);
2184
+ assert(params->wsize >= nbw2 * ne12);
1633
2185
 
1634
2186
  const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
1635
2187
 
1636
- int64_t i11_processed = 0;
1637
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
1638
- ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
1639
- }
2188
+ // INFO: Quantization is done in planes to avoid extra complexity in chunking.
2189
+ // Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
2190
+ // the planes are broadcast.
2191
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
2192
+ char * data_ptr = (char *) src1->data + i12 * nb12;
2193
+ char * wdata_ptr = wdata + i12 * nbw2;
1640
2194
 
1641
- i11_processed = ne11 - ne11 % 4;
1642
- for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
1643
- from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
2195
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
2196
+ ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
2197
+ (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
2198
+ }
2199
+
2200
+ const int64_t i11_processed = ne11 - ne11 % 4;
2201
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
2202
+ from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
2203
+ }
1644
2204
  }
1645
2205
 
1646
- ggml_barrier(params->threadpool);
2206
+ // disable for NUMA
2207
+ const bool disable_chunking = ggml_is_numa();
1647
2208
 
1648
- const void * src1_wdata = params->wdata;
1649
- const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
1650
- int64_t src0_start = (ith * ne01) / nth;
1651
- int64_t src0_end = ((ith + 1) * ne01) / nth;
1652
- src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
1653
- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
1654
- if (src0_start >= src0_end) {
1655
- return;
2209
+ // 4x chunks per thread
2210
+ const int64_t nr0 = ggml_nrows(op->src[0]);
2211
+
2212
+ int nth_scaled = nth * 4;
2213
+ int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
2214
+ int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
2215
+
2216
+ // src1 is chunked only by full planes.
2217
+ // When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
2218
+ // to route them through GEMV.
2219
+ // nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
2220
+ // to avoid affecting their performance
2221
+ int64_t nchunk1 = ne12;
2222
+
2223
+ // Ensure minimum chunk size to avoid alignment issues with high thread counts
2224
+ // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
2225
+ const int64_t min_chunk_size = NB_COLS;
2226
+ if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
2227
+ nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
1656
2228
  }
1657
2229
 
1658
- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
1659
- if (ne11 > 3) {
1660
- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1661
- (float *) ((char *) dst->data) + src0_start, ne01,
1662
- (const char *) src0->data + src0_start * nb01,
1663
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
2230
+ int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
2231
+ // Only increase nchunk0 to nth if it won't make chunks too small
2232
+ if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
2233
+ nchunk0 = nth;
2234
+ dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1664
2235
  }
1665
- for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
1666
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1667
- (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
1668
- (const char *) src0->data + src0_start * nb01,
1669
- (const char *) src1_wdata + (src1_col_stride * iter), 1,
1670
- src0_end - src0_start);
2236
+
2237
+ // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
2238
+ // This prevents creating too many tiny chunks that could overlap after alignment
2239
+ const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
2240
+ nchunk0 = MIN(nchunk0, max_nchunk);
2241
+
2242
+ if (ith == 0) {
2243
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
2244
+ ggml_threadpool_chunk_set(params->threadpool, nth);
2245
+ }
2246
+
2247
+ ggml_barrier(params->threadpool);
2248
+
2249
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
2250
+ int current_chunk = ith;
2251
+
2252
+ while (current_chunk < nchunk0 * nchunk1) {
2253
+ const int64_t ith0 = current_chunk % nchunk0;
2254
+ const int64_t ith1 = current_chunk / nchunk0;
2255
+
2256
+ int64_t src0_start = dr0 * ith0;
2257
+ int64_t src0_end = MIN(src0_start + dr0, nr0);
2258
+
2259
+ // full-plane range for src1
2260
+ int64_t src1_start = ith1 * ne11;
2261
+ int64_t src1_end = (ith1 + 1) * ne11;
2262
+
2263
+ // Align boundaries to NB_COLS - round up to ensure all data is included
2264
+ // The chunk size limiting above ensures chunks are large enough to prevent overlaps
2265
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
2266
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
2267
+ src0_end = MIN(src0_end, ne01);
2268
+
2269
+ // Make sure current plane is the last one before exiting
2270
+ if (src0_start >= src0_end) {
2271
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
2272
+ continue;
2273
+ }
2274
+
2275
+ forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
2276
+
2277
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
1671
2278
  }
1672
2279
  }
1673
2280
 
@@ -1772,8 +2379,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1772
2379
  int64_t src0_cur_start = (ith * ne01) / nth;
1773
2380
  int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
1774
2381
 
2382
+ // Align boundaries to NB_COLS - round up to ensure all data is included
1775
2383
  src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
1776
2384
  src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
2385
+ if (src0_cur_end > ne01) {
2386
+ src0_cur_end = ne01;
2387
+ }
1777
2388
 
1778
2389
  if (src0_cur_start >= src0_cur_end) {
1779
2390
  return;
@@ -1816,6 +2427,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
1816
2427
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
1817
2428
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
1818
2429
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
2430
+
2431
+ // instance for Q4_K
2432
+ static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
1819
2433
  static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
1820
2434
 
1821
2435
  // instance for Q2
@@ -1825,8 +2439,13 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
1825
2439
  static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
1826
2440
  static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
1827
2441
 
2442
+ // instance for Q8_0
2443
+ static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
2444
+ static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
2445
+
1828
2446
  if (cur->type == GGML_TYPE_Q4_0) {
1829
- if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
2447
+ if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
2448
+ || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
1830
2449
  if (cur->ne[1] % 8 == 0) {
1831
2450
  return &q4_0_8x8_q8_0;
1832
2451
  }
@@ -1847,6 +2466,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
1847
2466
  return &q4_K_8x8_q8_K;
1848
2467
  }
1849
2468
  }
2469
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
2470
+ if (cur->ne[1] % 8 == 0) {
2471
+ return &q4_K_8x8_q8_K;
2472
+ }
2473
+ }
2474
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
2475
+ if (cur->ne[1] % 8 == 0) {
2476
+ return &q4_K_8x4_q8_K;
2477
+ }
2478
+ }
1850
2479
  } else if (cur->type == GGML_TYPE_Q2_K) {
1851
2480
  if (ggml_cpu_has_avx512()) {
1852
2481
  if (cur->ne[1] % 8 == 0) {
@@ -1864,6 +2493,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
1864
2493
  return &iq4_nl_4x4_q8_0;
1865
2494
  }
1866
2495
  }
2496
+ } else if (cur->type == GGML_TYPE_Q8_0) {
2497
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
2498
+ if (cur->ne[1] % 4 == 0) {
2499
+ return &q8_0_4x8_q8_0;
2500
+ }
2501
+ }
2502
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
2503
+ if (cur->ne[1] % 4 == 0) {
2504
+ return &q8_0_4x4_q8_0;
2505
+ }
2506
+ }
1867
2507
  }
1868
2508
 
1869
2509
  return nullptr;