whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,11 +1,11 @@
1
- #include "llama-quant.h"
1
+ #include "llama.h"
2
2
  #include "llama-impl.h"
3
3
  #include "llama-model.h"
4
4
  #include "llama-model-loader.h"
5
5
 
6
- #include <algorithm>
7
6
  #include <cmath>
8
7
  #include <cstring>
8
+ #include <string>
9
9
  #include <cinttypes>
10
10
  #include <fstream>
11
11
  #include <mutex>
@@ -13,10 +13,28 @@
13
13
  #include <thread>
14
14
  #include <unordered_map>
15
15
 
16
- // Quantization types. Changes to this struct must be replicated in quantize.cpp
17
- struct tensor_quantization {
16
+ // result of parsing --tensor-type option
17
+ // (changes to this struct must be reflected in tools/quantize/quantize.cpp)
18
+ struct tensor_type_option {
18
19
  std::string name;
19
- ggml_type quant = GGML_TYPE_COUNT;
20
+ ggml_type type = GGML_TYPE_COUNT;
21
+ };
22
+
23
+ // tensor categorization - used to avoid repeated string matching in quantization logic.
24
+ // this is different from LLM_TN - we want broad categories, not specific tensor names per arch.
25
+ enum class tensor_category {
26
+ TOKEN_EMBD,
27
+ ATTENTION_Q,
28
+ ATTENTION_V,
29
+ ATTENTION_K,
30
+ ATTENTION_QKV,
31
+ ATTENTION_KV_B,
32
+ ATTENTION_OUTPUT,
33
+ FFN_UP,
34
+ FFN_GATE,
35
+ FFN_DOWN,
36
+ OUTPUT,
37
+ OTHER
20
38
  };
21
39
 
22
40
  static void zeros(std::ofstream & file, size_t n) {
@@ -54,7 +72,7 @@ static std::string remap_layer(const std::string & orig_name, const std::vector<
54
72
  return orig_name;
55
73
  }
56
74
 
57
- static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
75
+ static std::string remap_imatrix(const std::string & orig_name, const std::map<int, std::string> & mapped) {
58
76
  if (mapped.empty()) {
59
77
  return orig_name;
60
78
  }
@@ -76,6 +94,73 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map<
76
94
  return orig_name;
77
95
  }
78
96
 
97
+ //
98
+ // helper functions for tensor name matching
99
+ //
100
+
101
+ static bool tensor_name_match_token_embd(const char * tensor_name) {
102
+ return std::strcmp(tensor_name, "token_embd.weight") == 0 ||
103
+ std::strcmp(tensor_name, "per_layer_token_embd.weight") == 0;
104
+ }
105
+
106
+ static bool tensor_name_match_output_weight(const char * tensor_name) {
107
+ return std::strcmp(tensor_name, "output.weight") == 0;
108
+ }
109
+
110
+ //
111
+ // tensor categorization for quantization
112
+ //
113
+ // (this is different from LLM_TN - we want broad categories, not specific tensor names per arch)
114
+ //
115
+
116
+ static tensor_category tensor_get_category(const std::string & tensor_name) {
117
+ if (tensor_name_match_output_weight(tensor_name.c_str())) {
118
+ return tensor_category::OUTPUT;
119
+ }
120
+ if (tensor_name_match_token_embd(tensor_name.c_str())) {
121
+ return tensor_category::TOKEN_EMBD;
122
+ }
123
+ if (tensor_name.find("attn_qkv.weight") != std::string::npos) {
124
+ return tensor_category::ATTENTION_QKV;
125
+ }
126
+ if (tensor_name.find("attn_kv_b.weight") != std::string::npos) {
127
+ return tensor_category::ATTENTION_KV_B;
128
+ }
129
+ if (tensor_name.find("attn_v.weight") != std::string::npos) {
130
+ return tensor_category::ATTENTION_V;
131
+ }
132
+ if (tensor_name.find("attn_k.weight") != std::string::npos) {
133
+ return tensor_category::ATTENTION_K;
134
+ }
135
+ if (tensor_name.find("attn_q.weight") != std::string::npos) {
136
+ return tensor_category::ATTENTION_Q;
137
+ }
138
+ if (tensor_name.find("attn_output.weight") != std::string::npos) {
139
+ return tensor_category::ATTENTION_OUTPUT;
140
+ }
141
+ if (tensor_name.find("ffn_up") != std::string::npos) {
142
+ return tensor_category::FFN_UP;
143
+ }
144
+ if (tensor_name.find("ffn_gate") != std::string::npos) {
145
+ return tensor_category::FFN_GATE;
146
+ }
147
+ if (tensor_name.find("ffn_down") != std::string::npos) {
148
+ return tensor_category::FFN_DOWN;
149
+ }
150
+ return tensor_category::OTHER;
151
+ }
152
+
153
+ // check if category is for attention-v-like tensors (more sensitive to quantization)
154
+ static bool category_is_attn_v(tensor_category cat) {
155
+ return cat == tensor_category::ATTENTION_V ||
156
+ cat == tensor_category::ATTENTION_QKV ||
157
+ cat == tensor_category::ATTENTION_KV_B;
158
+ }
159
+
160
+ //
161
+ // quantization state
162
+ //
163
+
79
164
  struct quantize_state_impl {
80
165
  const llama_model & model;
81
166
  const llama_model_quantize_params * params;
@@ -89,20 +174,42 @@ struct quantize_state_impl {
89
174
  int i_ffn_gate = 0;
90
175
  int i_ffn_up = 0;
91
176
 
92
- int n_k_quantized = 0;
93
177
  int n_fallback = 0;
94
178
 
95
179
  bool has_imatrix = false;
96
180
 
97
- // used to figure out if a model shares tok_embd with the output weight
98
- bool has_output = false;
181
+ // used to figure out if a model has tied embeddings (tok_embd shares weights with output)
182
+ bool has_tied_embeddings = true; // assume tied until we see output.weight
183
+
184
+ // tensor type override patterns (compiled once, used twice)
185
+ std::vector<std::pair<std::regex, ggml_type>> tensor_type_patterns;
186
+
187
+ quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params):
188
+ model(model), params(params)
189
+ {
190
+ // compile regex patterns once - they are expensive
191
+ if (params->tensor_types) {
192
+ const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
193
+ for (const auto & [tname, qtype] : tensor_types) {
194
+ tensor_type_patterns.emplace_back(std::regex(tname), qtype);
195
+ }
196
+ }
197
+ }
198
+ };
99
199
 
100
- quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
101
- : model(model)
102
- , params(params)
103
- {}
200
+ // per-tensor metadata, computed in the preliminary loop and used in the main loop
201
+ struct tensor_metadata {
202
+ ggml_type target_type;
203
+ tensor_category category;
204
+ std::string remapped_imatrix_name;
205
+ bool allows_quantization;
206
+ bool requires_imatrix;
104
207
  };
105
208
 
209
+ //
210
+ // dequantization
211
+ //
212
+
106
213
  static void llama_tensor_dequantize_impl(
107
214
  ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
108
215
  const size_t nelements, const int nthread
@@ -175,12 +282,132 @@ static void llama_tensor_dequantize_impl(
175
282
  workers.clear();
176
283
  }
177
284
 
178
- static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
285
+ //
286
+ // do we allow this tensor to be quantized?
287
+ //
288
+
289
+ static bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor) {
290
+ // trivial checks first -- no string ops needed
291
+ if (params->only_copy) return false;
292
+
293
+ // quantize only 2D and 3D tensors (experts)
294
+ if (ggml_n_dims(tensor) < 2) return false;
295
+
296
+ const std::string name = ggml_get_name(tensor);
297
+
298
+ // This used to be a regex, but <regex> has an extreme cost to compile times.
299
+ bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
300
+
301
+ // do not quantize norm tensors
302
+ quantize &= name.find("_norm.weight") == std::string::npos;
303
+
304
+ quantize &= params->quantize_output_tensor || name != "output.weight";
305
+
306
+ // do not quantize expert gating tensors
307
+ // NOTE: can't use LLM_TN here because the layer number is not known
308
+ quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
309
+
310
+ // these are very small (e.g. 4x4)
311
+ quantize &= name.find("altup") == std::string::npos;
312
+ quantize &= name.find("laurel") == std::string::npos;
313
+
314
+ // these are not too big so keep them as it is
315
+ quantize &= name.find("per_layer_model_proj") == std::string::npos;
316
+
317
+ // do not quantize positional embeddings and token types (BERT)
318
+ quantize &= name != LLM_TN(arch)(LLM_TENSOR_POS_EMBD, "weight");
319
+ quantize &= name != LLM_TN(arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
320
+
321
+ // do not quantize Mamba/Kimi's small conv1d weights
322
+ // NOTE: can't use LLM_TN here because the layer number is not known
323
+ quantize &= name.find("ssm_conv1d") == std::string::npos;
324
+ quantize &= name.find("shortconv.conv.weight") == std::string::npos;
325
+
326
+ // do not quantize RWKV's small yet 2D weights
327
+ quantize &= name.find("time_mix_first.weight") == std::string::npos;
328
+ quantize &= name.find("time_mix_w0.weight") == std::string::npos;
329
+ quantize &= name.find("time_mix_w1.weight") == std::string::npos;
330
+ quantize &= name.find("time_mix_w2.weight") == std::string::npos;
331
+ quantize &= name.find("time_mix_v0.weight") == std::string::npos;
332
+ quantize &= name.find("time_mix_v1.weight") == std::string::npos;
333
+ quantize &= name.find("time_mix_v2.weight") == std::string::npos;
334
+ quantize &= name.find("time_mix_a0.weight") == std::string::npos;
335
+ quantize &= name.find("time_mix_a1.weight") == std::string::npos;
336
+ quantize &= name.find("time_mix_a2.weight") == std::string::npos;
337
+ quantize &= name.find("time_mix_g1.weight") == std::string::npos;
338
+ quantize &= name.find("time_mix_g2.weight") == std::string::npos;
339
+ quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
340
+ quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
341
+ quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
342
+
343
+ // do not quantize relative position bias (T5)
344
+ quantize &= name.find("attn_rel_b.weight") == std::string::npos;
345
+
346
+ // do not quantize specific multimodal tensors
347
+ quantize &= name.find(".position_embd.") == std::string::npos;
348
+
349
+ return quantize;
350
+ }
351
+
352
+ //
353
+ // tensor type selection
354
+ //
355
+
356
+ // incompatible tensor shapes are handled here - fallback to a compatible type
357
+ static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tensor * t, const ggml_type target_type) {
358
+ ggml_type return_type = target_type;
359
+
360
+ const int64_t ncols = t->ne[0];
361
+ const int64_t qk_k = ggml_blck_size(target_type);
362
+
363
+ if (ncols % qk_k != 0) { // this tensor's shape is incompatible with this quant
364
+ LLAMA_LOG_WARN("warning: %-36s - ncols %6" PRId64 " not divisible by %3" PRId64 " (required for type %7s) ",
365
+ t->name, ncols, qk_k, ggml_type_name(target_type));
366
+ ++qs.n_fallback;
367
+
368
+ switch (target_type) {
369
+ // types on the left: block size 256
370
+ case GGML_TYPE_IQ1_S:
371
+ case GGML_TYPE_IQ1_M:
372
+ case GGML_TYPE_IQ2_XXS:
373
+ case GGML_TYPE_IQ2_XS:
374
+ case GGML_TYPE_IQ2_S:
375
+ case GGML_TYPE_IQ3_XXS:
376
+ case GGML_TYPE_IQ3_S: // types on the right: block size 32
377
+ case GGML_TYPE_IQ4_XS: return_type = GGML_TYPE_IQ4_NL; break;
378
+ case GGML_TYPE_Q2_K:
379
+ case GGML_TYPE_Q3_K:
380
+ case GGML_TYPE_TQ1_0:
381
+ case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break;
382
+ case GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break;
383
+ case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break;
384
+ case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break;
385
+ default:
386
+ throw std::runtime_error(format("no tensor type fallback is defined for type %s",
387
+ ggml_type_name(target_type)));
388
+ }
389
+ if (ncols % ggml_blck_size(return_type) != 0) {
390
+ //
391
+ // the fallback return type is still not compatible for this tensor!
392
+ //
393
+ // most likely, this tensor's first dimension is not divisible by 32.
394
+ // this is very rare. we can either abort the quantization, or
395
+ // fallback to F16 / F32.
396
+ //
397
+ LLAMA_LOG_WARN("(WARNING: must use F16 due to unusual shape) ");
398
+ return_type = GGML_TYPE_F16;
399
+ }
400
+ LLAMA_LOG_WARN("-> falling back to %7s\n", ggml_type_name(return_type));
401
+ }
402
+ return return_type;
403
+ }
404
+
405
+ // internal standard logic for selecting the target tensor type based on tensor category, ftype, and model arch
406
+ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype, tensor_category category) {
179
407
  const std::string name = ggml_get_name(tensor);
180
408
 
181
409
  // TODO: avoid hardcoded tensor names - use the TN_* constants
182
410
  const llm_arch arch = qs.model.arch;
183
- const auto tn = LLM_TN(arch);
184
411
 
185
412
  auto use_more_bits = [](int i_layer, int n_layers) -> bool {
186
413
  return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
@@ -204,7 +431,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
204
431
 
205
432
  // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
206
433
  // with the quantization of the output tensor
207
- if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
434
+ if (category == tensor_category::OUTPUT || (qs.has_tied_embeddings && category == tensor_category::TOKEN_EMBD)) {
208
435
  if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
209
436
  new_type = qs.params->output_tensor_type;
210
437
  } else {
@@ -234,7 +461,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
234
461
  } else {
235
462
  new_type = GGML_TYPE_Q8_0;
236
463
  }
237
- } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
464
+ } else if (category == tensor_category::TOKEN_EMBD) {
238
465
  if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
239
466
  new_type = qs.params->token_embedding_type;
240
467
  } else {
@@ -254,21 +481,21 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
254
481
  }
255
482
  } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
256
483
  ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
257
- if (name.find("attn_v.weight") != std::string::npos) {
484
+ if (category_is_attn_v(category)) {
258
485
  if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
259
486
  else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
260
487
  ++qs.i_attention_wv;
261
488
  }
262
- else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
489
+ else if (qs.model.hparams.n_expert == 8 && category == tensor_category::ATTENTION_K) {
263
490
  new_type = GGML_TYPE_Q4_K;
264
491
  }
265
- else if (name.find("ffn_down") != std::string::npos) {
492
+ else if (category == tensor_category::FFN_DOWN) {
266
493
  if (qs.i_ffn_down < qs.n_ffn_down/8) {
267
494
  new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
268
495
  }
269
496
  ++qs.i_ffn_down;
270
497
  }
271
- else if (name.find("attn_output.weight") != std::string::npos) {
498
+ else if (category == tensor_category::ATTENTION_OUTPUT) {
272
499
  if (qs.model.hparams.n_expert == 8) {
273
500
  new_type = GGML_TYPE_Q5_K;
274
501
  } else {
@@ -276,7 +503,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
276
503
  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
277
504
  }
278
505
  }
279
- } else if (name.find("attn_v.weight") != std::string::npos) {
506
+ } else if (category_is_attn_v(category)) {
280
507
  if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
281
508
  new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
282
509
  }
@@ -314,7 +541,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
314
541
  new_type = GGML_TYPE_Q8_0;
315
542
  }
316
543
  ++qs.i_attention_wv;
317
- } else if (name.find("attn_k.weight") != std::string::npos) {
544
+ } else if (category == tensor_category::ATTENTION_K) {
318
545
  if (qs.model.hparams.n_expert == 8) {
319
546
  // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
320
547
  // TODO: explore better strategies
@@ -326,14 +553,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
326
553
  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
327
554
  new_type = GGML_TYPE_IQ2_S;
328
555
  }
329
- } else if (name.find("attn_q.weight") != std::string::npos) {
556
+ } else if (category == tensor_category::ATTENTION_Q) {
330
557
  if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
331
558
  new_type = GGML_TYPE_IQ3_XXS;
332
559
  }
333
560
  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
334
561
  new_type = GGML_TYPE_IQ2_S;
335
562
  }
336
- } else if (name.find("ffn_down") != std::string::npos) {
563
+ } else if (category == tensor_category::FFN_DOWN) {
337
564
  auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
338
565
  int i_layer = info.first, n_layer = info.second;
339
566
  if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
@@ -378,7 +605,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
378
605
  new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
379
606
  }
380
607
  ++qs.i_ffn_down;
381
- } else if (name.find("attn_output.weight") != std::string::npos) {
608
+ } else if (category == tensor_category::ATTENTION_OUTPUT) {
382
609
  if (arch != LLM_ARCH_FALCON) {
383
610
  if (qs.model.hparams.n_expert == 8) {
384
611
  if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -398,14 +625,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
398
625
  if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
399
626
  }
400
627
  }
401
- else if (name.find("attn_qkv.weight") != std::string::npos) {
628
+ else if (category == tensor_category::ATTENTION_QKV) {
402
629
  if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
403
630
  new_type = GGML_TYPE_Q4_K;
404
631
  }
405
632
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
406
633
  else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
407
634
  }
408
- else if (name.find("ffn_gate") != std::string::npos) {
635
+ else if (category == tensor_category::FFN_GATE) {
409
636
  auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
410
637
  int i_layer = info.first, n_layer = info.second;
411
638
  if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
@@ -413,7 +640,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
413
640
  }
414
641
  ++qs.i_ffn_gate;
415
642
  }
416
- else if (name.find("ffn_up") != std::string::npos) {
643
+ else if (category == tensor_category::FFN_UP) {
417
644
  auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
418
645
  int i_layer = info.first, n_layer = info.second;
419
646
  if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
@@ -422,60 +649,58 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
422
649
  ++qs.i_ffn_up;
423
650
  }
424
651
 
425
- // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
426
- //}
427
- // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
428
- //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
429
- // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
430
- //}
431
- // This can be used to reduce the size of the Q5_K_S model.
432
- // The associated PPL increase is fully in line with the size reduction
433
- //else {
434
- // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
435
- //}
436
- bool convert_incompatible_tensor = false;
437
- {
438
- const int64_t nx = tensor->ne[0];
439
- const int64_t ny = tensor->ne[1];
440
- const int64_t qk_k = ggml_blck_size(new_type);
652
+ return new_type;
653
+ }
441
654
 
442
- if (nx % qk_k != 0) {
443
- LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
444
- convert_incompatible_tensor = true;
445
- } else {
446
- ++qs.n_k_quantized;
447
- }
655
+ // outer wrapper: determine the ggml_type that this tensor should be quantized to
656
+ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) {
657
+ if (!tensor_allows_quantization(params, qs.model.arch, tensor)) {
658
+ return tensor->type;
659
+ }
660
+ if (params->token_embedding_type < GGML_TYPE_COUNT && tm.category == tensor_category::TOKEN_EMBD) {
661
+ return params->token_embedding_type;
662
+ }
663
+ if (params->output_tensor_type < GGML_TYPE_COUNT && tm.category == tensor_category::OUTPUT) {
664
+ return params->output_tensor_type;
448
665
  }
449
666
 
450
- if (convert_incompatible_tensor) {
451
- switch (new_type) {
452
- case GGML_TYPE_TQ1_0:
453
- case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
454
- case GGML_TYPE_IQ2_XXS:
455
- case GGML_TYPE_IQ2_XS:
456
- case GGML_TYPE_IQ2_S:
457
- case GGML_TYPE_IQ3_XXS:
458
- case GGML_TYPE_IQ3_S:
459
- case GGML_TYPE_IQ1_S:
460
- case GGML_TYPE_IQ1_M:
461
- case GGML_TYPE_Q2_K:
462
- case GGML_TYPE_Q3_K:
463
- case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
464
- case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
465
- case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
466
- case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
467
- default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
667
+ ggml_type new_type = default_type;
668
+
669
+ // get more optimal quantization type based on the tensor shape, layer, etc.
670
+ if (!params->pure && ggml_is_quantized(default_type)) {
671
+ // if the user provided tensor types - use those
672
+ bool manual = false;
673
+ if (!qs.tensor_type_patterns.empty()) {
674
+ const std::string tensor_name(tensor->name);
675
+ for (const auto & [pattern, qtype] : qs.tensor_type_patterns) {
676
+ if (std::regex_search(tensor_name, pattern)) {
677
+ if (qtype != new_type) {
678
+ LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
679
+ __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype));
680
+ new_type = qtype;
681
+ manual = true;
682
+ break;
683
+ }
684
+ }
685
+ }
468
686
  }
469
- if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
470
- new_type = GGML_TYPE_F16;
687
+
688
+ // if not manual - use the standard logic for choosing the quantization type based on the selected mixture
689
+ if (!manual) {
690
+ new_type = llama_tensor_get_type_impl(qs, new_type, tensor, params->ftype, tm.category);
471
691
  }
472
- LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
473
- ++qs.n_fallback;
692
+
693
+ // incompatible tensor shapes are handled here - fallback to a compatible type
694
+ new_type = tensor_type_fallback(qs, tensor, new_type);
474
695
  }
475
696
 
476
697
  return new_type;
477
698
  }
478
699
 
700
+ //
701
+ // quantization implementation
702
+ //
703
+
479
704
  static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
480
705
  if (nthread < 2) {
481
706
  // single-thread
@@ -530,50 +755,85 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
530
755
  return new_size;
531
756
  }
532
757
 
533
- static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
534
- ggml_type default_type;
535
- llama_ftype ftype = params->ftype;
758
+ //
759
+ // imatrix requirement check
760
+ //
761
+
762
+ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type dst_type, const llama_ftype ftype) {
763
+ if (tensor_name_match_token_embd(tensor_name) || tensor_name_match_output_weight(tensor_name)) {
764
+ return false;
765
+ }
766
+ switch (dst_type) {
767
+ case GGML_TYPE_IQ3_XXS:
768
+ case GGML_TYPE_IQ2_XXS:
769
+ case GGML_TYPE_IQ2_XS:
770
+ case GGML_TYPE_IQ2_S:
771
+ case GGML_TYPE_IQ1_M:
772
+ case GGML_TYPE_IQ1_S:
773
+ return true;
774
+ case GGML_TYPE_Q2_K:
775
+ // as a general rule, the k-type quantizations don't require imatrix data.
776
+ // the only exception is Q2_K tensors that are part of a Q2_K_S file.
777
+ return ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S;
778
+ default:
779
+ return false;
780
+ }
781
+ }
782
+
783
+ //
784
+ // given a file type, get the default tensor type
785
+ //
536
786
 
537
- switch (params->ftype) {
538
- case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
539
- case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
540
- case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
541
- case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
542
- case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
543
- case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
544
- case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
545
- case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
787
+ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
788
+ switch (ftype) {
789
+ case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0;
790
+ case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1;
791
+ case LLAMA_FTYPE_MOSTLY_Q5_0: return GGML_TYPE_Q5_0;
792
+ case LLAMA_FTYPE_MOSTLY_Q5_1: return GGML_TYPE_Q5_1;
793
+ case LLAMA_FTYPE_MOSTLY_Q8_0: return GGML_TYPE_Q8_0;
794
+ case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16;
795
+ case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16;
796
+ case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32;
546
797
 
547
- case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;
798
+ case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4;
548
799
 
549
800
  // K-quants
550
801
  case LLAMA_FTYPE_MOSTLY_Q2_K_S:
551
- case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
552
- case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
802
+ case LLAMA_FTYPE_MOSTLY_Q2_K: return GGML_TYPE_Q2_K;
803
+ case LLAMA_FTYPE_MOSTLY_IQ3_XS: return GGML_TYPE_IQ3_S;
553
804
  case LLAMA_FTYPE_MOSTLY_Q3_K_S:
554
805
  case LLAMA_FTYPE_MOSTLY_Q3_K_M:
555
- case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
806
+ case LLAMA_FTYPE_MOSTLY_Q3_K_L: return GGML_TYPE_Q3_K;
556
807
  case LLAMA_FTYPE_MOSTLY_Q4_K_S:
557
- case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
808
+ case LLAMA_FTYPE_MOSTLY_Q4_K_M: return GGML_TYPE_Q4_K;
558
809
  case LLAMA_FTYPE_MOSTLY_Q5_K_S:
559
- case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
560
- case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
561
- case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
562
- case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
563
- case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
564
- case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
565
- case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
566
- case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
567
- case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
568
- case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
569
- case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
570
- case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
571
- case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
572
- case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
573
- case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
810
+ case LLAMA_FTYPE_MOSTLY_Q5_K_M: return GGML_TYPE_Q5_K;
811
+ case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K;
812
+ case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0;
813
+ case LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0;
814
+ case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS;
815
+ case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS;
816
+ case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS;
817
+ case LLAMA_FTYPE_MOSTLY_IQ2_M: return GGML_TYPE_IQ2_S;
818
+ case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return GGML_TYPE_IQ3_XXS;
819
+ case LLAMA_FTYPE_MOSTLY_IQ1_S: return GGML_TYPE_IQ1_S;
820
+ case LLAMA_FTYPE_MOSTLY_IQ1_M: return GGML_TYPE_IQ1_M;
821
+ case LLAMA_FTYPE_MOSTLY_IQ4_NL: return GGML_TYPE_IQ4_NL;
822
+ case LLAMA_FTYPE_MOSTLY_IQ4_XS: return GGML_TYPE_IQ4_XS;
823
+ case LLAMA_FTYPE_MOSTLY_IQ3_S:
824
+ case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S;
574
825
 
575
826
  default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
576
827
  }
828
+ }
829
+
830
+ //
831
+ // main quantization driver
832
+ //
833
+
834
+ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
835
+ ggml_type default_type;
836
+ llama_ftype ftype = params->ftype;
577
837
 
578
838
  int nthread = params->nthread;
579
839
 
@@ -581,6 +841,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
581
841
  nthread = std::thread::hardware_concurrency();
582
842
  }
583
843
 
844
+ default_type = llama_ftype_get_default_type(ftype);
845
+
584
846
  // mmap consistently increases speed on Linux, and also increases speed on Windows with
585
847
  // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
586
848
  #if defined(__linux__) || defined(_WIN32)
@@ -596,7 +858,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
596
858
  }
597
859
 
598
860
  std::vector<std::string> splits = {};
599
- llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
861
+ llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr,
862
+ fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
600
863
  ml.init_mappings(false); // no prefetching
601
864
 
602
865
  llama_model model(llama_model_default_params());
@@ -614,7 +877,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
614
877
  if (params->imatrix) {
615
878
  imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
616
879
  if (imatrix_data) {
617
- LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
880
+ LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n",
881
+ __func__, (int)imatrix_data->size());
618
882
  qs.has_imatrix = true;
619
883
  // check imatrix for nans or infs
620
884
  for (const auto & kv : *imatrix_data) {
@@ -636,7 +900,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
636
900
  }
637
901
 
638
902
  // copy the KV pairs from the input file
639
- gguf_set_kv (ctx_out.get(), ml.meta.get());
903
+ gguf_set_kv (ctx_out.get(), ml.metadata);
640
904
  gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
641
905
  gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
642
906
 
@@ -653,7 +917,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
653
917
  gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
654
918
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
655
919
  // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
656
- gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
920
+ gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
657
921
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
658
922
  gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
659
923
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
@@ -666,7 +930,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
666
930
 
667
931
  std::map<int, std::string> mapped;
668
932
  int blk_id = 0;
669
- int pruned_attention_w = 0;
670
933
 
671
934
  // make a list of weights
672
935
  std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
@@ -674,14 +937,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
674
937
  for (const auto & it : ml.weights_map) {
675
938
  const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
676
939
  if (remapped_name.empty()) {
677
- if (it.first.find("attn_v.weight") != std::string::npos ||
678
- it.first.find("attn_qkv.weight") != std::string::npos ||
679
- it.first.find("attn_kv_b.weight") != std::string::npos) {
680
- pruned_attention_w++;
681
- }
682
940
  LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
683
941
  continue;
684
- } else if (remapped_name != it.first) {
942
+ }
943
+
944
+ if (remapped_name != it.first) {
685
945
  ggml_set_name(it.second.tensor, remapped_name.c_str());
686
946
  LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
687
947
  }
@@ -701,49 +961,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
701
961
  });
702
962
  }
703
963
 
704
- for (const auto * it : tensors) {
705
- const struct ggml_tensor * tensor = it->tensor;
706
-
707
- const std::string name = ggml_get_name(tensor);
708
-
709
- // TODO: avoid hardcoded tensor names - use the TN_* constants
710
- if (name.find("attn_v.weight") != std::string::npos ||
711
- name.find("attn_qkv.weight") != std::string::npos ||
712
- name.find("attn_kv_b.weight")!= std::string::npos) {
713
- ++qs.n_attention_wv;
714
- } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
715
- qs.has_output = true;
716
- }
717
- }
718
-
719
- qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
720
-
721
- // sanity checks for models that have attention layers
722
- if (qs.n_attention_wv != 0)
723
- {
724
- const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
725
- // attention layers have a non-zero number of kv heads
726
- int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
727
- if (llama_model_has_encoder(&model)) {
728
- // now n_attn_layer is the number of attention layers in the encoder
729
- // for each decoder block, there are 2 attention layers
730
- n_attn_layer += 2 * model.hparams.dec_n_layer;
731
- }
732
- GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
733
- }
734
-
735
- size_t total_size_org = 0;
736
- size_t total_size_new = 0;
737
-
738
- std::vector<std::thread> workers;
739
- workers.reserve(nthread);
740
-
741
964
  int idx = 0;
742
-
743
- std::vector<no_init<uint8_t>> read_data;
744
- std::vector<no_init<uint8_t>> work;
745
- std::vector<no_init<float>> f32_conv_buf;
746
-
747
965
  uint16_t n_split = 1;
748
966
 
749
967
  // Assume split index is continuous
@@ -755,14 +973,68 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
755
973
  std::vector<gguf_context_ptr> ctx_outs(n_split);
756
974
  ctx_outs[0] = std::move(ctx_out);
757
975
 
758
- // populate the original tensors so we get an initial meta data
759
- for (const auto * it : tensors) {
976
+ // compute tensor metadata once and cache it
977
+ std::vector<tensor_metadata> metadata(tensors.size());
978
+
979
+ // initialize quantization state before preliminary loop (counters for use_more_bits)
980
+ {
981
+ for (size_t i = 0; i < tensors.size(); ++i) {
982
+ const auto cat = tensor_get_category(tensors[i]->tensor->name);
983
+ if (category_is_attn_v(cat)) {
984
+ ++qs.n_attention_wv;
985
+ }
986
+ if (cat == tensor_category::OUTPUT) {
987
+ qs.has_tied_embeddings = false;
988
+ }
989
+ metadata[i].category = cat; // save and re-use the category while we're at it
990
+ }
991
+ // these also need to be set to n_layer by default
992
+ qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
993
+ }
994
+
995
+ // flag for --dry-run
996
+ bool will_require_imatrix = false;
997
+
998
+ //
999
+ // preliminary iteration over all weights
1000
+ //
1001
+
1002
+ for (size_t i = 0; i < tensors.size(); ++i) {
1003
+ const auto * it = tensors[i];
1004
+ const struct ggml_tensor * tensor = it->tensor;
1005
+ const std::string name = ggml_get_name(tensor);
1006
+
760
1007
  uint16_t i_split = params->keep_split ? it->idx : 0;
761
- ggml_tensor * tensor = it->tensor;
762
1008
  if (!ctx_outs[i_split]) {
763
1009
  ctx_outs[i_split].reset(gguf_init_empty());
764
1010
  }
765
1011
  gguf_add_tensor(ctx_outs[i_split].get(), tensor);
1012
+
1013
+ metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor);
1014
+
1015
+ if (metadata[i].allows_quantization) {
1016
+ metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]);
1017
+ } else {
1018
+ metadata[i].target_type = tensor->type;
1019
+ }
1020
+
1021
+ metadata[i].requires_imatrix = tensor_requires_imatrix(tensor->name, metadata[i].target_type, ftype);
1022
+
1023
+ if (params->imatrix) {
1024
+ metadata[i].remapped_imatrix_name = remap_imatrix(tensor->name, mapped);
1025
+ } else if (metadata[i].allows_quantization && metadata[i].requires_imatrix) {
1026
+ if (params->dry_run) {
1027
+ will_require_imatrix = true;
1028
+ } else {
1029
+ LLAMA_LOG_ERROR("\n============================================================================\n"
1030
+ " ERROR: this quantization requires an importance matrix!\n"
1031
+ " - offending tensor: %s\n"
1032
+ " - target type: %s\n"
1033
+ "============================================================================\n\n",
1034
+ name.c_str(), ggml_type_name(metadata[i].target_type));
1035
+ throw std::runtime_error("this quantization requires an imatrix!");
1036
+ }
1037
+ }
766
1038
  }
767
1039
 
768
1040
  // Set split info if needed
@@ -774,6 +1046,16 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
774
1046
  }
775
1047
  }
776
1048
 
1049
+ size_t total_size_org = 0;
1050
+ size_t total_size_new = 0;
1051
+
1052
+ std::vector<std::thread> workers;
1053
+ workers.reserve(nthread);
1054
+
1055
+ std::vector<no_init<uint8_t>> read_data;
1056
+ std::vector<no_init<uint8_t>> work;
1057
+ std::vector<no_init<float>> f32_conv_buf;
1058
+
777
1059
  int cur_split = -1;
778
1060
  std::ofstream fout;
779
1061
  auto close_ofstream = [&]() {
@@ -803,248 +1085,182 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
803
1085
  ::zeros(fout, meta_size);
804
1086
  };
805
1087
 
806
- const auto tn = LLM_TN(model.arch);
807
- new_ofstream(0);
808
- for (const auto * it : tensors) {
809
- const auto & weight = *it;
1088
+ // no output file for --dry-run
1089
+ if (!params->dry_run) {
1090
+ new_ofstream(0);
1091
+ }
1092
+
1093
+ //
1094
+ // main loop: iterate over all weights
1095
+ //
1096
+
1097
+ for (size_t i = 0; i < tensors.size(); ++i) {
1098
+ const auto & weight = *tensors[i];
1099
+ const auto & tm = metadata[i];
810
1100
  ggml_tensor * tensor = weight.tensor;
811
- if (weight.idx != cur_split && params->keep_split) {
1101
+
1102
+ if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) {
812
1103
  close_ofstream();
813
1104
  new_ofstream(weight.idx);
814
1105
  }
815
1106
 
816
1107
  const std::string name = ggml_get_name(tensor);
1108
+ const size_t tensor_size = ggml_nbytes(tensor);
817
1109
 
818
- if (!ml.use_mmap) {
819
- if (read_data.size() < ggml_nbytes(tensor)) {
820
- read_data.resize(ggml_nbytes(tensor));
1110
+ if (!params->dry_run) {
1111
+ if (!ml.use_mmap) {
1112
+ if (read_data.size() < tensor_size) {
1113
+ read_data.resize(tensor_size);
1114
+ }
1115
+ tensor->data = read_data.data();
821
1116
  }
822
- tensor->data = read_data.data();
1117
+ ml.load_data_for(tensor);
823
1118
  }
824
- ml.load_data_for(tensor);
825
1119
 
826
- LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
1120
+ LLAMA_LOG_INFO("[%4d/%4d] %-36s - [%s], type = %6s, ",
827
1121
  ++idx, ml.n_tensors,
828
1122
  ggml_get_name(tensor),
829
1123
  llama_format_tensor_shape(tensor).c_str(),
830
1124
  ggml_type_name(tensor->type));
831
1125
 
832
- // This used to be a regex, but <regex> has an extreme cost to compile times.
833
- bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
834
-
835
- // quantize only 2D and 3D tensors (experts)
836
- quantize &= (ggml_n_dims(tensor) >= 2);
837
-
838
- // do not quantize norm tensors
839
- quantize &= name.find("_norm.weight") == std::string::npos;
840
-
841
- quantize &= params->quantize_output_tensor || name != "output.weight";
842
- quantize &= !params->only_copy;
843
-
844
- // do not quantize expert gating tensors
845
- // NOTE: can't use LLM_TN here because the layer number is not known
846
- quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
847
-
848
- // these are very small (e.g. 4x4)
849
- quantize &= name.find("altup") == std::string::npos;
850
- quantize &= name.find("laurel") == std::string::npos;
851
-
852
- // these are not too big so keep them as it is
853
- quantize &= name.find("per_layer_model_proj") == std::string::npos;
854
-
855
- // do not quantize positional embeddings and token types (BERT)
856
- quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
857
- quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
858
-
859
- // do not quantize Mamba's small yet 2D weights
860
- // NOTE: can't use LLM_TN here because the layer number is not known
861
- quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
862
- quantize &= name.find("shortconv.conv.weight") == std::string::npos;
863
-
864
- // do not quantize RWKV's small yet 2D weights
865
- quantize &= name.find("time_mix_first.weight") == std::string::npos;
866
- quantize &= name.find("time_mix_w0.weight") == std::string::npos;
867
- quantize &= name.find("time_mix_w1.weight") == std::string::npos;
868
- quantize &= name.find("time_mix_w2.weight") == std::string::npos;
869
- quantize &= name.find("time_mix_v0.weight") == std::string::npos;
870
- quantize &= name.find("time_mix_v1.weight") == std::string::npos;
871
- quantize &= name.find("time_mix_v2.weight") == std::string::npos;
872
- quantize &= name.find("time_mix_a0.weight") == std::string::npos;
873
- quantize &= name.find("time_mix_a1.weight") == std::string::npos;
874
- quantize &= name.find("time_mix_a2.weight") == std::string::npos;
875
- quantize &= name.find("time_mix_g1.weight") == std::string::npos;
876
- quantize &= name.find("time_mix_g2.weight") == std::string::npos;
877
- quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
878
- quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
879
- quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
880
-
881
- // do not quantize relative position bias (T5)
882
- quantize &= name.find("attn_rel_b.weight") == std::string::npos;
883
-
884
- ggml_type new_type;
1126
+ const ggml_type cur_type = tensor->type;
1127
+ const ggml_type new_type = tm.target_type;
1128
+
1129
+ // If we've decided to quantize to the same type the tensor is already
1130
+ // in then there's nothing to do.
1131
+ bool quantize = cur_type != new_type;
1132
+
885
1133
  void * new_data;
886
1134
  size_t new_size;
887
1135
 
888
- if (quantize) {
889
- new_type = default_type;
890
-
891
- // get more optimal quantization type based on the tensor shape, layer, etc.
892
- if (!params->pure && ggml_is_quantized(default_type)) {
893
- int fallback = qs.n_fallback;
894
- new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
895
- // unless the user specifies a type, and the tensor geometry will not require fallback quantisation
896
- if (params->tensor_types && qs.n_fallback - fallback == 0) {
897
- const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
898
- const std::string tensor_name(tensor->name);
899
- for (const auto & [tname, qtype] : tensor_types) {
900
- if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
901
- if (qtype != new_type) {
902
- LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
903
- new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
904
- }
905
- }
906
- }
1136
+ if (params->dry_run) {
1137
+ // the --dry-run option calculates the final quantization size without quantizing
1138
+ if (quantize) {
1139
+ new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]);
1140
+ LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n",
1141
+ tensor_size/1024.0/1024.0,
1142
+ new_size/1024.0/1024.0,
1143
+ ggml_type_name(new_type));
1144
+ if (!will_require_imatrix && tm.requires_imatrix) {
1145
+ will_require_imatrix = true;
907
1146
  }
1147
+ } else {
1148
+ new_size = tensor_size;
1149
+ LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0);
908
1150
  }
909
- if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
910
- new_type = params->token_embedding_type;
911
- }
912
- if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
913
- new_type = params->output_tensor_type;
914
- }
915
-
916
- // If we've decided to quantize to the same type the tensor is already
917
- // in then there's nothing to do.
918
- quantize = tensor->type != new_type;
919
- }
920
-
921
- if (!quantize) {
922
- new_type = tensor->type;
923
- new_data = tensor->data;
924
- new_size = ggml_nbytes(tensor);
925
- LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0);
1151
+ total_size_org += tensor_size;
1152
+ total_size_new += new_size;
1153
+ continue;
926
1154
  } else {
927
- const int64_t nelements = ggml_nelements(tensor);
1155
+ // no --dry-run, perform quantization
1156
+ if (!quantize) {
1157
+ new_data = tensor->data;
1158
+ new_size = tensor_size;
1159
+ LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0);
1160
+ } else {
1161
+ const int64_t nelements = ggml_nelements(tensor);
928
1162
 
929
- const float * imatrix = nullptr;
930
- if (imatrix_data) {
931
- auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
932
- if (it == imatrix_data->end()) {
933
- LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
934
- } else {
935
- if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
936
- imatrix = it->second.data();
1163
+ const float * imatrix = nullptr;
1164
+ if (imatrix_data) {
1165
+ auto it = imatrix_data->find(tm.remapped_imatrix_name);
1166
+ if (it == imatrix_data->end()) {
1167
+ LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
937
1168
  } else {
938
- LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
939
- int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
940
-
941
- // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
942
- // this is a significant error and it may be good idea to abort the process if this happens,
943
- // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
944
- // tok_embd should be ignored in this case, since it always causes this warning
945
- if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
946
- throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
947
- int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
1169
+ if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
1170
+ imatrix = it->second.data();
1171
+ } else {
1172
+ LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
1173
+ int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
1174
+
1175
+ // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
1176
+ // this is a significant error and it may be good idea to abort the process if this happens,
1177
+ // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
1178
+ // tok_embd should be ignored in this case, since it always causes this warning
1179
+ if (!tensor_name_match_token_embd(tensor->name)) {
1180
+ throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
1181
+ int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
1182
+ }
948
1183
  }
949
1184
  }
950
1185
  }
951
- }
952
- if ((new_type == GGML_TYPE_IQ2_XXS ||
953
- new_type == GGML_TYPE_IQ2_XS ||
954
- new_type == GGML_TYPE_IQ2_S ||
955
- new_type == GGML_TYPE_IQ1_S ||
956
- (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
957
- (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
958
- LLAMA_LOG_ERROR("\n\n============================================================\n");
959
- LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
960
- LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
961
- LLAMA_LOG_ERROR("============================================================\n\n");
962
- throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
963
- }
1186
+ if (!imatrix && tm.requires_imatrix) {
1187
+ LLAMA_LOG_ERROR("\n\n============================================================\n");
1188
+ LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
1189
+ LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
1190
+ LLAMA_LOG_ERROR("============================================================\n\n");
1191
+ throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
1192
+ }
964
1193
 
965
- float * f32_data;
1194
+ float * f32_data;
966
1195
 
967
- if (tensor->type == GGML_TYPE_F32) {
968
- f32_data = (float *) tensor->data;
969
- } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
970
- throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
971
- } else {
972
- llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
973
- f32_data = (float *) f32_conv_buf.data();
974
- }
1196
+ if (tensor->type == GGML_TYPE_F32) {
1197
+ f32_data = (float *) tensor->data;
1198
+ } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
1199
+ throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
1200
+ } else {
1201
+ llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
1202
+ f32_data = (float *) f32_conv_buf.data();
1203
+ }
975
1204
 
976
- LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
977
- fflush(stdout);
1205
+ LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
1206
+ fflush(stdout);
978
1207
 
979
- if (work.size() < (size_t)nelements * 4) {
980
- work.resize(nelements * 4); // upper bound on size
981
- }
982
- new_data = work.data();
983
-
984
- const int64_t n_per_row = tensor->ne[0];
985
- const int64_t nrows = tensor->ne[1];
986
-
987
- static const int64_t min_chunk_size = 32 * 512;
988
- const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
989
-
990
- const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
991
- const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
992
- const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
993
-
994
- // quantize each expert separately since they have different importance matrices
995
- new_size = 0;
996
- for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
997
- const float * f32_data_03 = f32_data + i03 * nelements_matrix;
998
- void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
999
- const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
1000
-
1001
- new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
1002
-
1003
- // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
1004
- #if 0
1005
- if (new_type == GGML_TYPE_MXFP4) {
1006
- auto * x = f32_data_03;
1007
-
1008
- //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
1009
- std::vector<float> deq(nrows*n_per_row);
1010
- const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
1011
- qtype->to_float(new_data_03, deq.data(), deq.size());
1012
-
1013
- double err = 0.0f;
1014
- for (int i = 0; i < (int) deq.size(); ++i) {
1015
- err += fabsf(deq[i] - x[i]);
1016
- //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
1017
- if (deq[i] != x[i]) {
1018
- LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
1019
- }
1020
- }
1021
- //LLAMA_LOG_INFO("err = %f\n", err);
1022
- GGML_ASSERT(err == 0.00000);
1208
+ if (work.size() < (size_t)nelements * 4) {
1209
+ work.resize(nelements * 4); // upper bound on size
1023
1210
  }
1024
- #endif
1025
- }
1026
- LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
1027
- }
1028
- total_size_org += ggml_nbytes(tensor);
1029
- total_size_new += new_size;
1211
+ new_data = work.data();
1212
+
1213
+ const int64_t n_per_row = tensor->ne[0];
1214
+ const int64_t nrows = tensor->ne[1];
1030
1215
 
1031
- // update the gguf meta data as we go
1032
- gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
1033
- GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
1034
- gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
1216
+ static const int64_t min_chunk_size = 32 * 512;
1217
+ const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
1035
1218
 
1036
- // write tensor data + padding
1037
- fout.write((const char *) new_data, new_size);
1038
- zeros(fout, GGML_PAD(new_size, align) - new_size);
1219
+ const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
1220
+ const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
1221
+ const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
1222
+
1223
+ // quantize each expert separately since they have different importance matrices
1224
+ new_size = 0;
1225
+ for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
1226
+ const float * f32_data_03 = f32_data + i03 * nelements_matrix;
1227
+ void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
1228
+ const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
1229
+
1230
+ new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
1231
+ }
1232
+ LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0);
1233
+ }
1234
+ total_size_org += tensor_size;
1235
+ total_size_new += new_size;
1236
+
1237
+ // update the gguf meta data as we go
1238
+ gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
1239
+ GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
1240
+ gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
1241
+
1242
+ // write tensor data + padding
1243
+ fout.write((const char *) new_data, new_size);
1244
+ zeros(fout, GGML_PAD(new_size, align) - new_size);
1245
+ } // no --dry-run
1246
+ } // main loop
1247
+
1248
+ if (!params->dry_run) {
1249
+ close_ofstream();
1039
1250
  }
1040
- close_ofstream();
1041
1251
 
1042
- LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0);
1043
- LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0);
1252
+ LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements);
1253
+ LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements);
1254
+
1255
+ if (!params->imatrix && params->dry_run && will_require_imatrix) {
1256
+ LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n",
1257
+ __func__
1258
+ );
1259
+ }
1044
1260
 
1045
1261
  if (qs.n_fallback > 0) {
1046
1262
  LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
1047
- __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
1263
+ __func__, qs.n_fallback, ml.n_tensors);
1048
1264
  }
1049
1265
  }
1050
1266
 
@@ -1063,6 +1279,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
1063
1279
  /*.only_copy =*/ false,
1064
1280
  /*.pure =*/ false,
1065
1281
  /*.keep_split =*/ false,
1282
+ /*.dry_run =*/ false,
1066
1283
  /*.imatrix =*/ nullptr,
1067
1284
  /*.kv_overrides =*/ nullptr,
1068
1285
  /*.tensor_type =*/ nullptr,