whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -61,9 +61,6 @@
61
61
  #define m512i(p) (__m512i)(p)
62
62
  #endif
63
63
 
64
- // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
65
- float ggml_table_f32_f16[1 << 16];
66
-
67
64
  #if defined(__linux__) || \
68
65
  defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
69
66
  (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
@@ -133,7 +130,7 @@ static void ggml_print_backtrace_symbols(void) {
133
130
  }
134
131
  #endif
135
132
 
136
- static void ggml_print_backtrace(void) {
133
+ void ggml_print_backtrace(void) {
137
134
  const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
138
135
  if (GGML_NO_BACKTRACE) {
139
136
  return;
@@ -160,6 +157,10 @@ static void ggml_print_backtrace(void) {
160
157
  const int parent_pid = getpid();
161
158
  const int child_pid = fork();
162
159
  if (child_pid < 0) { // error
160
+ #if defined(__linux__)
161
+ close(lock[1]);
162
+ close(lock[0]);
163
+ #endif
163
164
  return;
164
165
  } else if (child_pid == 0) { // child
165
166
  char attach[32];
@@ -167,6 +168,7 @@ static void ggml_print_backtrace(void) {
167
168
  #if defined(__linux__)
168
169
  close(lock[1]);
169
170
  (void) !read(lock[0], lock, 1);
171
+ close(lock[0]);
170
172
  #endif
171
173
  // try gdb
172
174
  execlp("gdb", "gdb", "--batch",
@@ -195,27 +197,44 @@ static void ggml_print_backtrace(void) {
195
197
  }
196
198
  }
197
199
  #else
198
- static void ggml_print_backtrace(void) {
200
+ void ggml_print_backtrace(void) {
199
201
  // platform not supported
200
202
  }
201
203
  #endif
202
204
 
205
+ static ggml_abort_callback_t g_abort_callback = NULL;
206
+
207
+ // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
208
+ GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {
209
+ ggml_abort_callback_t ret_val = g_abort_callback;
210
+ g_abort_callback = callback;
211
+ return ret_val;
212
+ }
213
+
203
214
  void ggml_abort(const char * file, int line, const char * fmt, ...) {
204
215
  fflush(stdout);
205
216
 
206
- fprintf(stderr, "%s:%d: ", file, line);
217
+ char message[2048];
218
+ int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);
207
219
 
208
220
  va_list args;
209
221
  va_start(args, fmt);
210
- vfprintf(stderr, fmt, args);
222
+ vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
211
223
  va_end(args);
212
224
 
213
- fprintf(stderr, "\n");
225
+ if (g_abort_callback) {
226
+ g_abort_callback(message);
227
+ } else {
228
+ // default: print error and backtrace to stderr
229
+ fprintf(stderr, "%s\n", message);
230
+ ggml_print_backtrace();
231
+ }
214
232
 
215
- ggml_print_backtrace();
216
233
  abort();
217
234
  }
218
235
 
236
+ // ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
237
+
219
238
  //
220
239
  // logging
221
240
  //
@@ -454,6 +473,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
454
473
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
455
474
  }
456
475
 
476
+ const char * ggml_version(void) {
477
+ return GGML_VERSION;
478
+ }
479
+
480
+ const char * ggml_commit(void) {
481
+ return GGML_COMMIT;
482
+ }
483
+
457
484
  //
458
485
  // timing
459
486
  //
@@ -555,9 +582,6 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
555
582
  #endif
556
583
 
557
584
  }
558
- static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
559
- static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
560
- static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
561
585
 
562
586
  static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
563
587
  [GGML_TYPE_I8] = {
@@ -663,6 +687,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
663
687
  .is_quantized = true,
664
688
  .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
665
689
  },
690
+ [GGML_TYPE_MXFP4] = {
691
+ .type_name = "mxfp4",
692
+ .blck_size = QK_MXFP4,
693
+ .type_size = sizeof(block_mxfp4),
694
+ .is_quantized = true,
695
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
696
+ .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
697
+ },
666
698
  [GGML_TYPE_Q2_K] = {
667
699
  .type_name = "q2_K",
668
700
  .blck_size = QK_K,
@@ -881,12 +913,6 @@ struct ggml_context {
881
913
  struct ggml_object * objects_end;
882
914
  };
883
915
 
884
- struct ggml_context_container {
885
- bool used;
886
-
887
- struct ggml_context context;
888
- };
889
-
890
916
  //
891
917
  // data types
892
918
  //
@@ -896,6 +922,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
896
922
 
897
923
  "DUP",
898
924
  "ADD",
925
+ "ADD_ID",
899
926
  "ADD1",
900
927
  "ACC",
901
928
  "SUB",
@@ -935,6 +962,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
935
962
  "TRANSPOSE",
936
963
  "GET_ROWS",
937
964
  "GET_ROWS_BACK",
965
+ "SET_ROWS",
938
966
  "DIAG",
939
967
  "DIAG_MASK_INF",
940
968
  "DIAG_MASK_ZERO",
@@ -946,6 +974,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
946
974
  "CONV_TRANSPOSE_1D",
947
975
  "IM2COL",
948
976
  "IM2COL_BACK",
977
+ "IM2COL_3D",
978
+ "CONV_2D",
979
+ "CONV_3D",
949
980
  "CONV_2D_DW",
950
981
  "CONV_TRANSPOSE_2D",
951
982
  "POOL_1D",
@@ -954,6 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
954
985
  "UPSCALE",
955
986
  "PAD",
956
987
  "PAD_REFLECT_1D",
988
+ "ROLL",
957
989
  "ARANGE",
958
990
  "TIMESTEP_EMBEDDING",
959
991
  "ARGSORT",
@@ -982,15 +1014,19 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
982
1014
  "CROSS_ENTROPY_LOSS",
983
1015
  "CROSS_ENTROPY_LOSS_BACK",
984
1016
  "OPT_STEP_ADAMW",
1017
+ "OPT_STEP_SGD",
1018
+
1019
+ "GLU",
985
1020
  };
986
1021
 
987
- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1022
+ static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
988
1023
 
989
1024
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
990
1025
  "none",
991
1026
 
992
1027
  "x",
993
1028
  "x+y",
1029
+ "x[i]+y",
994
1030
  "x+y",
995
1031
  "view(x,nb,offset)+=y->x",
996
1032
  "x-y",
@@ -1030,6 +1066,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1030
1066
  "transpose(x)",
1031
1067
  "get_rows(x)",
1032
1068
  "get_rows_back(x)",
1069
+ "set_rows(x)",
1033
1070
  "diag(x)",
1034
1071
  "diag_mask_inf(x)",
1035
1072
  "diag_mask_zero(x)",
@@ -1041,6 +1078,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1041
1078
  "conv_transpose_1d(x)",
1042
1079
  "im2col(x)",
1043
1080
  "im2col_back(x)",
1081
+ "im2col_3d(x)",
1082
+ "conv_2d(x)",
1083
+ "conv_3d(x)",
1044
1084
  "conv_2d_dw(x)",
1045
1085
  "conv_transpose_2d(x)",
1046
1086
  "pool_1d(x)",
@@ -1049,6 +1089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1049
1089
  "upscale(x)",
1050
1090
  "pad(x)",
1051
1091
  "pad_reflect_1d(x)",
1092
+ "roll(x)",
1052
1093
  "arange(start, stop, step)",
1053
1094
  "timestep_embedding(timesteps, dim, max_period)",
1054
1095
  "argsort(x)",
@@ -1077,13 +1118,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1077
1118
  "cross_entropy_loss(x,y)",
1078
1119
  "cross_entropy_loss_back(x,y)",
1079
1120
  "adamw(x)",
1121
+ "sgd(x)",
1122
+
1123
+ "glu(x)",
1080
1124
  };
1081
1125
 
1082
- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1126
+ static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
1083
1127
 
1084
1128
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1085
1129
 
1086
-
1087
1130
  static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1088
1131
  "ABS",
1089
1132
  "SGN",
@@ -1105,6 +1148,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1105
1148
  static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1106
1149
 
1107
1150
 
1151
+ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
1152
+ "REGLU",
1153
+ "GEGLU",
1154
+ "SWIGLU",
1155
+ "SWIGLU_OAI",
1156
+ "GEGLU_ERF",
1157
+ "GEGLU_QUICK",
1158
+ };
1159
+
1160
+ static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
1161
+
1162
+
1108
1163
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1109
1164
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1110
1165
 
@@ -1207,11 +1262,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
1207
1262
  return GGML_UNARY_OP_NAME[op];
1208
1263
  }
1209
1264
 
1265
+ const char * ggml_glu_op_name(enum ggml_glu_op op) {
1266
+ return GGML_GLU_OP_NAME[op];
1267
+ }
1268
+
1210
1269
  const char * ggml_op_desc(const struct ggml_tensor * t) {
1211
1270
  if (t->op == GGML_OP_UNARY) {
1212
1271
  enum ggml_unary_op uop = ggml_get_unary_op(t);
1213
1272
  return ggml_unary_op_name(uop);
1214
1273
  }
1274
+ if (t->op == GGML_OP_GLU) {
1275
+ enum ggml_glu_op gop = ggml_get_glu_op(t);
1276
+ return ggml_glu_op_name(gop);
1277
+ }
1215
1278
  return ggml_op_name(t->op);
1216
1279
  }
1217
1280
 
@@ -1262,6 +1325,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
1262
1325
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
1263
1326
  case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
1264
1327
  case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
1328
+ case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
1265
1329
  case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
1266
1330
  case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
1267
1331
  case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
@@ -1348,6 +1412,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
1348
1412
  tensor->nb[2] == ggml_type_size(tensor->type);
1349
1413
  }
1350
1414
 
1415
+ bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
1416
+ return
1417
+ tensor->ne[0] == ggml_blck_size(tensor->type) ||
1418
+ tensor->nb[0] == ggml_type_size(tensor->type);
1419
+ }
1420
+
1351
1421
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
1352
1422
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1353
1423
 
@@ -1419,14 +1489,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
1419
1489
  // initialize time system (required on Windows)
1420
1490
  ggml_time_init();
1421
1491
 
1422
- for (int i = 0; i < (1 << 16); ++i) {
1423
- union {
1424
- uint16_t u16;
1425
- ggml_fp16_t fp16;
1426
- } u = {i};
1427
- ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
1428
- }
1429
-
1430
1492
  is_first_call = false;
1431
1493
  }
1432
1494
 
@@ -1730,6 +1792,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
1730
1792
  return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
1731
1793
  }
1732
1794
 
1795
+ enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1796
+ GGML_ASSERT(tensor->op == GGML_OP_GLU);
1797
+ return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1798
+ }
1799
+
1733
1800
  const char * ggml_get_name(const struct ggml_tensor * tensor) {
1734
1801
  return tensor->name;
1735
1802
  }
@@ -1909,6 +1976,27 @@ struct ggml_tensor * ggml_add_cast(
1909
1976
  return ggml_add_cast_impl(ctx, a, b, type);
1910
1977
  }
1911
1978
 
1979
+ struct ggml_tensor * ggml_add_id(
1980
+ struct ggml_context * ctx,
1981
+ struct ggml_tensor * a,
1982
+ struct ggml_tensor * b,
1983
+ struct ggml_tensor * ids) {
1984
+
1985
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
1986
+ GGML_ASSERT(a->ne[1] == ids->ne[0]);
1987
+ GGML_ASSERT(a->ne[2] == ids->ne[1]);
1988
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
1989
+
1990
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
1991
+
1992
+ result->op = GGML_OP_ADD_ID;
1993
+ result->src[0] = a;
1994
+ result->src[1] = b;
1995
+ result->src[2] = ids;
1996
+
1997
+ return result;
1998
+ }
1999
+
1912
2000
  // ggml_add1
1913
2001
 
1914
2002
  static struct ggml_tensor * ggml_add1_impl(
@@ -2312,6 +2400,26 @@ struct ggml_tensor * ggml_repeat(
2312
2400
  return result;
2313
2401
  }
2314
2402
 
2403
+ struct ggml_tensor * ggml_repeat_4d(
2404
+ struct ggml_context * ctx,
2405
+ struct ggml_tensor * a,
2406
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
2407
+ const bool can_repeat = ggml_is_empty(a) || (
2408
+ (ne0 % a->ne[0] == 0) &&
2409
+ (ne1 % a->ne[1] == 0) &&
2410
+ (ne2 % a->ne[2] == 0) &&
2411
+ (ne3 % a->ne[3] == 0)
2412
+ );
2413
+ GGML_ASSERT(can_repeat);
2414
+
2415
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
2416
+
2417
+ result->op = GGML_OP_REPEAT;
2418
+ result->src[0] = a;
2419
+
2420
+ return result;
2421
+ }
2422
+
2315
2423
  // ggml_repeat_back
2316
2424
 
2317
2425
  struct ggml_tensor * ggml_repeat_back(
@@ -2589,6 +2697,169 @@ struct ggml_tensor * ggml_exp_inplace(
2589
2697
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
2590
2698
  }
2591
2699
 
2700
+ // ggml_glu
2701
+
2702
+ static struct ggml_tensor * ggml_glu_impl(
2703
+ struct ggml_context * ctx,
2704
+ struct ggml_tensor * a,
2705
+ struct ggml_tensor * b,
2706
+ enum ggml_glu_op op,
2707
+ bool swapped) {
2708
+ GGML_ASSERT(ggml_is_contiguous_1(a));
2709
+
2710
+ if (b) {
2711
+ GGML_ASSERT(ggml_is_contiguous_1(b));
2712
+ GGML_ASSERT(ggml_are_same_shape(a, b));
2713
+ GGML_ASSERT(a->type == b->type);
2714
+ }
2715
+
2716
+ int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2717
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
2718
+
2719
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
2720
+ ggml_set_op_params_i32(result, 1, (int32_t) swapped);
2721
+
2722
+ result->op = GGML_OP_GLU;
2723
+ result->src[0] = a;
2724
+ result->src[1] = b;
2725
+
2726
+ return result;
2727
+ }
2728
+
2729
+ struct ggml_tensor * ggml_glu(
2730
+ struct ggml_context * ctx,
2731
+ struct ggml_tensor * a,
2732
+ enum ggml_glu_op op,
2733
+ bool swapped) {
2734
+ return ggml_glu_impl(ctx, a, NULL, op, swapped);
2735
+ }
2736
+
2737
+ struct ggml_tensor * ggml_glu_split(
2738
+ struct ggml_context * ctx,
2739
+ struct ggml_tensor * a,
2740
+ struct ggml_tensor * b,
2741
+ enum ggml_glu_op op) {
2742
+ return ggml_glu_impl(ctx, a, b, op, false);
2743
+ }
2744
+
2745
+ // ggml_reglu
2746
+
2747
+ struct ggml_tensor * ggml_reglu(
2748
+ struct ggml_context * ctx,
2749
+ struct ggml_tensor * a) {
2750
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
2751
+ }
2752
+
2753
+ struct ggml_tensor * ggml_reglu_swapped(
2754
+ struct ggml_context * ctx,
2755
+ struct ggml_tensor * a) {
2756
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
2757
+ }
2758
+
2759
+ struct ggml_tensor * ggml_reglu_split(
2760
+ struct ggml_context * ctx,
2761
+ struct ggml_tensor * a,
2762
+ struct ggml_tensor * b) {
2763
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
2764
+ }
2765
+
2766
+ // ggml_geglu
2767
+
2768
+ struct ggml_tensor * ggml_geglu(
2769
+ struct ggml_context * ctx,
2770
+ struct ggml_tensor * a) {
2771
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
2772
+ }
2773
+
2774
+ struct ggml_tensor * ggml_geglu_swapped(
2775
+ struct ggml_context * ctx,
2776
+ struct ggml_tensor * a) {
2777
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
2778
+ }
2779
+
2780
+ struct ggml_tensor * ggml_geglu_split(
2781
+ struct ggml_context * ctx,
2782
+ struct ggml_tensor * a,
2783
+ struct ggml_tensor * b) {
2784
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
2785
+ }
2786
+
2787
+ // ggml_swiglu
2788
+
2789
+ struct ggml_tensor * ggml_swiglu(
2790
+ struct ggml_context * ctx,
2791
+ struct ggml_tensor * a) {
2792
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
2793
+ }
2794
+
2795
+ struct ggml_tensor * ggml_swiglu_swapped(
2796
+ struct ggml_context * ctx,
2797
+ struct ggml_tensor * a) {
2798
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
2799
+ }
2800
+
2801
+ struct ggml_tensor * ggml_swiglu_split(
2802
+ struct ggml_context * ctx,
2803
+ struct ggml_tensor * a,
2804
+ struct ggml_tensor * b) {
2805
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
2806
+ }
2807
+
2808
+ // ggml_geglu_erf
2809
+
2810
+ struct ggml_tensor * ggml_geglu_erf(
2811
+ struct ggml_context * ctx,
2812
+ struct ggml_tensor * a) {
2813
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
2814
+ }
2815
+
2816
+ struct ggml_tensor * ggml_geglu_erf_swapped(
2817
+ struct ggml_context * ctx,
2818
+ struct ggml_tensor * a) {
2819
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
2820
+ }
2821
+
2822
+ struct ggml_tensor * ggml_geglu_erf_split(
2823
+ struct ggml_context * ctx,
2824
+ struct ggml_tensor * a,
2825
+ struct ggml_tensor * b) {
2826
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
2827
+ }
2828
+
2829
+ // ggml_geglu_quick
2830
+
2831
+ struct ggml_tensor * ggml_geglu_quick(
2832
+ struct ggml_context * ctx,
2833
+ struct ggml_tensor * a) {
2834
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
2835
+ }
2836
+
2837
+ struct ggml_tensor * ggml_geglu_quick_swapped(
2838
+ struct ggml_context * ctx,
2839
+ struct ggml_tensor * a) {
2840
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
2841
+ }
2842
+
2843
+ struct ggml_tensor * ggml_geglu_quick_split(
2844
+ struct ggml_context * ctx,
2845
+ struct ggml_tensor * a,
2846
+ struct ggml_tensor * b) {
2847
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
2848
+ }
2849
+
2850
+ struct ggml_tensor * ggml_swiglu_oai(
2851
+ struct ggml_context * ctx,
2852
+ struct ggml_tensor * a,
2853
+ struct ggml_tensor * b,
2854
+ float alpha,
2855
+ float limit) {
2856
+ struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
2857
+ ggml_set_op_params_f32(result, 2, alpha);
2858
+ ggml_set_op_params_f32(result, 3, limit);
2859
+
2860
+ return result;
2861
+ }
2862
+
2592
2863
  // ggml_norm
2593
2864
 
2594
2865
  static struct ggml_tensor * ggml_norm_impl(
@@ -2846,12 +3117,14 @@ static struct ggml_tensor * ggml_scale_impl(
2846
3117
  struct ggml_context * ctx,
2847
3118
  struct ggml_tensor * a,
2848
3119
  float s,
3120
+ float b,
2849
3121
  bool inplace) {
2850
3122
  GGML_ASSERT(ggml_is_padded_1d(a));
2851
3123
 
2852
3124
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2853
3125
 
2854
- ggml_set_op_params(result, &s, sizeof(s));
3126
+ float params[2] = { s, b };
3127
+ ggml_set_op_params(result, &params, sizeof(params));
2855
3128
 
2856
3129
  result->op = GGML_OP_SCALE;
2857
3130
  result->src[0] = a;
@@ -2863,14 +3136,30 @@ struct ggml_tensor * ggml_scale(
2863
3136
  struct ggml_context * ctx,
2864
3137
  struct ggml_tensor * a,
2865
3138
  float s) {
2866
- return ggml_scale_impl(ctx, a, s, false);
3139
+ return ggml_scale_impl(ctx, a, s, 0.0, false);
2867
3140
  }
2868
3141
 
2869
3142
  struct ggml_tensor * ggml_scale_inplace(
2870
3143
  struct ggml_context * ctx,
2871
3144
  struct ggml_tensor * a,
2872
3145
  float s) {
2873
- return ggml_scale_impl(ctx, a, s, true);
3146
+ return ggml_scale_impl(ctx, a, s, 0.0, true);
3147
+ }
3148
+
3149
+ struct ggml_tensor * ggml_scale_bias(
3150
+ struct ggml_context * ctx,
3151
+ struct ggml_tensor * a,
3152
+ float s,
3153
+ float b) {
3154
+ return ggml_scale_impl(ctx, a, s, b, false);
3155
+ }
3156
+
3157
+ struct ggml_tensor * ggml_scale_bias_inplace(
3158
+ struct ggml_context * ctx,
3159
+ struct ggml_tensor * a,
3160
+ float s,
3161
+ float b) {
3162
+ return ggml_scale_impl(ctx, a, s, b, true);
2874
3163
  }
2875
3164
 
2876
3165
  // ggml_set
@@ -3334,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
3334
3623
  struct ggml_tensor * a,
3335
3624
  struct ggml_tensor * b) {
3336
3625
  GGML_ASSERT(a->ne[2] == b->ne[1]);
3626
+ GGML_ASSERT(a->ne[3] == b->ne[2]);
3337
3627
  GGML_ASSERT(b->ne[3] == 1);
3338
3628
  GGML_ASSERT(b->type == GGML_TYPE_I32);
3339
3629
 
@@ -3372,6 +3662,36 @@ struct ggml_tensor * ggml_get_rows_back(
3372
3662
  return result;
3373
3663
  }
3374
3664
 
3665
+ // ggml_set_rows
3666
+
3667
+ struct ggml_tensor * ggml_set_rows(
3668
+ struct ggml_context * ctx,
3669
+ struct ggml_tensor * a,
3670
+ struct ggml_tensor * b,
3671
+ struct ggml_tensor * c) {
3672
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
3673
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
3674
+ GGML_ASSERT(a->ne[3] == b->ne[3]);
3675
+ GGML_ASSERT(b->ne[1] == c->ne[0]);
3676
+ GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
3677
+ GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3678
+ GGML_ASSERT(c->ne[3] == 1);
3679
+ GGML_ASSERT(b->type == GGML_TYPE_F32);
3680
+ GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
3681
+
3682
+ GGML_ASSERT(ggml_is_contiguous_rows(a));
3683
+ GGML_ASSERT(ggml_is_contiguous_rows(b));
3684
+
3685
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3686
+
3687
+ result->op = GGML_OP_SET_ROWS;
3688
+ result->src[0] = b;
3689
+ result->src[1] = c;
3690
+ result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
3691
+
3692
+ return result;
3693
+ }
3694
+
3375
3695
  // ggml_diag
3376
3696
 
3377
3697
  struct ggml_tensor * ggml_diag(
@@ -3466,9 +3786,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
3466
3786
  if (mask) {
3467
3787
  GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
3468
3788
  GGML_ASSERT(ggml_is_contiguous(mask));
3469
- GGML_ASSERT(ggml_is_matrix(mask));
3470
3789
  GGML_ASSERT(mask->ne[0] == a->ne[0]);
3471
3790
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
3791
+ GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
3792
+ GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
3472
3793
  }
3473
3794
 
3474
3795
  if (max_bias > 0.0f) {
@@ -3508,6 +3829,22 @@ struct ggml_tensor * ggml_soft_max_ext(
3508
3829
  return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3509
3830
  }
3510
3831
 
3832
+ void ggml_soft_max_add_sinks(
3833
+ struct ggml_tensor * a,
3834
+ struct ggml_tensor * sinks) {
3835
+ if (!sinks) {
3836
+ a->src[2] = NULL;
3837
+ return;
3838
+ }
3839
+
3840
+ GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
3841
+ GGML_ASSERT(a->src[2] == NULL);
3842
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
3843
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
3844
+
3845
+ a->src[2] = sinks;
3846
+ }
3847
+
3511
3848
  // ggml_soft_max_ext_back
3512
3849
 
3513
3850
  static struct ggml_tensor * ggml_soft_max_ext_back_impl(
@@ -3555,6 +3892,7 @@ static struct ggml_tensor * ggml_rope_impl(
3555
3892
  struct ggml_tensor * b,
3556
3893
  struct ggml_tensor * c,
3557
3894
  int n_dims,
3895
+ int sections[GGML_MROPE_SECTIONS],
3558
3896
  int mode,
3559
3897
  int n_ctx_orig,
3560
3898
  float freq_base,
@@ -3568,15 +3906,19 @@ static struct ggml_tensor * ggml_rope_impl(
3568
3906
 
3569
3907
  GGML_ASSERT(ggml_is_vector(b));
3570
3908
  GGML_ASSERT(b->type == GGML_TYPE_I32);
3571
- GGML_ASSERT(a->ne[2] == b->ne[0]);
3909
+
3910
+ bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
3911
+ if (mrope_used) {
3912
+ GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
3913
+ } else {
3914
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
3915
+ }
3572
3916
 
3573
3917
  if (c) {
3574
3918
  GGML_ASSERT(c->type == GGML_TYPE_F32);
3575
3919
  GGML_ASSERT(c->ne[0] >= n_dims / 2);
3576
3920
  }
3577
3921
 
3578
- int sections[4] = {0, 0, 0, 0};
3579
-
3580
3922
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3581
3923
 
3582
3924
  int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -3586,7 +3928,11 @@ static struct ggml_tensor * ggml_rope_impl(
3586
3928
  memcpy(params + 8, &attn_factor, sizeof(float));
3587
3929
  memcpy(params + 9, &beta_fast, sizeof(float));
3588
3930
  memcpy(params + 10, &beta_slow, sizeof(float));
3589
- memcpy(params + 11, &sections, sizeof(int)*4);
3931
+ if (mrope_used && sections) {
3932
+ memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
3933
+ } else {
3934
+ memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
3935
+ }
3590
3936
  ggml_set_op_params(result, params, sizeof(params));
3591
3937
 
3592
3938
  result->op = GGML_OP_ROPE;
@@ -3604,7 +3950,7 @@ struct ggml_tensor * ggml_rope(
3604
3950
  int n_dims,
3605
3951
  int mode) {
3606
3952
  return ggml_rope_impl(
3607
- ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
3953
+ ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
3608
3954
  );
3609
3955
  }
3610
3956
 
@@ -3614,7 +3960,7 @@ struct ggml_tensor * ggml_rope_multi(
3614
3960
  struct ggml_tensor * b,
3615
3961
  struct ggml_tensor * c,
3616
3962
  int n_dims,
3617
- int sections[4],
3963
+ int sections[GGML_MROPE_SECTIONS],
3618
3964
  int mode,
3619
3965
  int n_ctx_orig,
3620
3966
  float freq_base,
@@ -3623,36 +3969,31 @@ struct ggml_tensor * ggml_rope_multi(
3623
3969
  float attn_factor,
3624
3970
  float beta_fast,
3625
3971
  float beta_slow) {
3626
- // Multimodal Rotary Position Embedding
3627
- GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
3628
-
3629
- GGML_ASSERT(ggml_is_vector(b));
3630
- GGML_ASSERT(b->type == GGML_TYPE_I32);
3631
- GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
3632
-
3633
- if (c) {
3634
- GGML_ASSERT(c->type == GGML_TYPE_F32);
3635
- GGML_ASSERT(c->ne[0] >= n_dims / 2);
3636
- }
3637
-
3638
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3639
-
3640
- int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3641
- memcpy(params + 5, &freq_base, sizeof(float));
3642
- memcpy(params + 6, &freq_scale, sizeof(float));
3643
- memcpy(params + 7, &ext_factor, sizeof(float));
3644
- memcpy(params + 8, &attn_factor, sizeof(float));
3645
- memcpy(params + 9, &beta_fast, sizeof(float));
3646
- memcpy(params + 10, &beta_slow, sizeof(float));
3647
- memcpy(&params[11], sections, sizeof(int)*4);
3648
- ggml_set_op_params(result, params, sizeof(params));
3649
-
3650
- result->op = GGML_OP_ROPE;
3651
- result->src[0] = a;
3652
- result->src[1] = b;
3653
- result->src[2] = c;
3972
+ return ggml_rope_impl(
3973
+ ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
3974
+ ext_factor, attn_factor, beta_fast, beta_slow, false
3975
+ );
3976
+ }
3654
3977
 
3655
- return result;
3978
+ struct ggml_tensor * ggml_rope_multi_inplace(
3979
+ struct ggml_context * ctx,
3980
+ struct ggml_tensor * a,
3981
+ struct ggml_tensor * b,
3982
+ struct ggml_tensor * c,
3983
+ int n_dims,
3984
+ int sections[GGML_MROPE_SECTIONS],
3985
+ int mode,
3986
+ int n_ctx_orig,
3987
+ float freq_base,
3988
+ float freq_scale,
3989
+ float ext_factor,
3990
+ float attn_factor,
3991
+ float beta_fast,
3992
+ float beta_slow) {
3993
+ return ggml_rope_impl(
3994
+ ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
3995
+ ext_factor, attn_factor, beta_fast, beta_slow, true
3996
+ );
3656
3997
  }
3657
3998
 
3658
3999
  struct ggml_tensor * ggml_rope_inplace(
@@ -3662,7 +4003,7 @@ struct ggml_tensor * ggml_rope_inplace(
3662
4003
  int n_dims,
3663
4004
  int mode) {
3664
4005
  return ggml_rope_impl(
3665
- ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
4006
+ ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
3666
4007
  );
3667
4008
  }
3668
4009
 
@@ -3681,7 +4022,7 @@ struct ggml_tensor * ggml_rope_ext(
3681
4022
  float beta_fast,
3682
4023
  float beta_slow) {
3683
4024
  return ggml_rope_impl(
3684
- ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4025
+ ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
3685
4026
  ext_factor, attn_factor, beta_fast, beta_slow, false
3686
4027
  );
3687
4028
  }
@@ -3701,7 +4042,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
3701
4042
  float beta_fast,
3702
4043
  float beta_slow) {
3703
4044
  return ggml_rope_impl(
3704
- ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4045
+ ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
3705
4046
  ext_factor, attn_factor, beta_fast, beta_slow, true
3706
4047
  );
3707
4048
  }
@@ -3720,7 +4061,7 @@ struct ggml_tensor * ggml_rope_custom(
3720
4061
  float beta_fast,
3721
4062
  float beta_slow) {
3722
4063
  return ggml_rope_impl(
3723
- ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4064
+ ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
3724
4065
  ext_factor, attn_factor, beta_fast, beta_slow, false
3725
4066
  );
3726
4067
  }
@@ -3739,7 +4080,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
3739
4080
  float beta_fast,
3740
4081
  float beta_slow) {
3741
4082
  return ggml_rope_impl(
3742
- ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4083
+ ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
3743
4084
  ext_factor, attn_factor, beta_fast, beta_slow, true
3744
4085
  );
3745
4086
  }
@@ -3937,14 +4278,13 @@ struct ggml_tensor * ggml_conv_1d_dw(
3937
4278
  int s0,
3938
4279
  int p0,
3939
4280
  int d0) {
3940
- struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
3941
4281
  struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
3942
4282
 
3943
- struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
4283
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
3944
4284
 
3945
4285
  struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
3946
4286
 
3947
- result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
4287
+ result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
3948
4288
 
3949
4289
  return result;
3950
4290
  }
@@ -4025,6 +4365,91 @@ struct ggml_tensor * ggml_conv_2d(
4025
4365
  return result;
4026
4366
  }
4027
4367
 
4368
+ // a: [OC*IC, KD, KH, KW]
4369
+ // b: [N*IC, ID, IH, IW]
4370
+ // result: [N*OD, OH, OW, IC * KD * KH * KW]
4371
+ struct ggml_tensor * ggml_im2col_3d(
4372
+ struct ggml_context * ctx,
4373
+ struct ggml_tensor * a,
4374
+ struct ggml_tensor * b,
4375
+ int64_t IC,
4376
+ int s0, // stride width
4377
+ int s1, // stride height
4378
+ int s2, // stride depth
4379
+ int p0, // padding width
4380
+ int p1, // padding height
4381
+ int p2, // padding depth
4382
+ int d0, // dilation width
4383
+ int d1, // dilation height
4384
+ int d2, // dilation depth
4385
+ enum ggml_type dst_type) {
4386
+ const int64_t N = b->ne[3] / IC;
4387
+ const int64_t ID = b->ne[2];
4388
+ const int64_t IH = b->ne[1];
4389
+ const int64_t IW = b->ne[0];
4390
+
4391
+ const int64_t OC = a->ne[3] / IC;
4392
+ UNUSED(OC);
4393
+ const int64_t KD = a->ne[2];
4394
+ const int64_t KH = a->ne[1];
4395
+ const int64_t KW = a->ne[0];
4396
+ const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
4397
+ const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
4398
+ const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
4399
+
4400
+ GGML_ASSERT((OD > 0) && "b too small compared to a");
4401
+ GGML_ASSERT((OH > 0) && "b too small compared to a");
4402
+ GGML_ASSERT((OW > 0) && "b too small compared to a");
4403
+
4404
+
4405
+ const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
4406
+
4407
+ struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
4408
+ int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
4409
+ ggml_set_op_params(result, params, sizeof(params));
4410
+
4411
+ result->op = GGML_OP_IM2COL_3D;
4412
+ result->src[0] = a;
4413
+ result->src[1] = b;
4414
+
4415
+ return result;
4416
+ }
4417
+
4418
+ // a: [OC*IC, KD, KH, KW]
4419
+ // b: [N*IC, ID, IH, IW]
4420
+ // result: [N*OC, OD, OH, OW]
4421
+ struct ggml_tensor * ggml_conv_3d(
4422
+ struct ggml_context * ctx,
4423
+ struct ggml_tensor * a,
4424
+ struct ggml_tensor * b,
4425
+ int64_t IC,
4426
+ int s0, // stride width
4427
+ int s1, // stride height
4428
+ int s2, // stride depth
4429
+ int p0, // padding width
4430
+ int p1, // padding height
4431
+ int p2, // padding depth
4432
+ int d0, // dilation width
4433
+ int d1, // dilation height
4434
+ int d2 // dilation depth
4435
+ ) {
4436
+ struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
4437
+
4438
+ int64_t OC = a->ne[3] / IC;
4439
+ int64_t N = b->ne[3] / IC;
4440
+ struct ggml_tensor * result =
4441
+ ggml_mul_mat(ctx,
4442
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
4443
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
4444
+
4445
+ int64_t OD = im2col->ne[3] / N;
4446
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
4447
+ result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
4448
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
4449
+
4450
+ return result;
4451
+ }
4452
+
4028
4453
  // ggml_conv_2d_sk_p0
4029
4454
 
4030
4455
  struct ggml_tensor * ggml_conv_2d_sk_p0(
@@ -4108,6 +4533,94 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
4108
4533
  return result;
4109
4534
  }
4110
4535
 
4536
+ // ggml_conv_2d_direct
4537
+
4538
+ struct ggml_tensor * ggml_conv_2d_direct(
4539
+ struct ggml_context * ctx,
4540
+ struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
4541
+ struct ggml_tensor * b, // input data [W, H, C, N]
4542
+ int s0, // stride dimension 0
4543
+ int s1, // stride dimension 1
4544
+ int p0, // padding dimension 0
4545
+ int p1, // padding dimension 1
4546
+ int d0, // dilation dimension 0
4547
+ int d1) {// dilation dimension 1
4548
+
4549
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
4550
+ //GGML_ASSERT(a->type == b->type);
4551
+
4552
+ int64_t ne[4];
4553
+ ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4554
+ ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4555
+ ne[2] = a->ne[3];
4556
+ ne[3] = b->ne[3];
4557
+
4558
+ struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4559
+
4560
+ ggml_set_op_params_i32(result, 0, s0);
4561
+ ggml_set_op_params_i32(result, 1, s1);
4562
+ ggml_set_op_params_i32(result, 2, p0);
4563
+ ggml_set_op_params_i32(result, 3, p1);
4564
+ ggml_set_op_params_i32(result, 4, d0);
4565
+ ggml_set_op_params_i32(result, 5, d1);
4566
+
4567
+ result->op = GGML_OP_CONV_2D;
4568
+ result->src[0] = a;
4569
+ result->src[1] = b;
4570
+
4571
+ return result;
4572
+ }
4573
+
4574
+ // ggml_conv_3d_direct
4575
+
4576
+ struct ggml_tensor * ggml_conv_3d_direct(
4577
+ struct ggml_context * ctx,
4578
+ struct ggml_tensor * a,
4579
+ struct ggml_tensor * b,
4580
+ int s0,
4581
+ int s1,
4582
+ int s2,
4583
+ int p0,
4584
+ int p1,
4585
+ int p2,
4586
+ int d0,
4587
+ int d1,
4588
+ int d2,
4589
+ int c,
4590
+ int n,
4591
+ int oc) {
4592
+
4593
+ GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
4594
+ GGML_ASSERT(b->ne[3] == (int64_t) c * n);
4595
+
4596
+ int64_t ne[4];
4597
+ ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4598
+ ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4599
+ ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
4600
+ ne[3] = (int64_t) oc * n;
4601
+
4602
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4603
+
4604
+ ggml_set_op_params_i32(result, 0, s0);
4605
+ ggml_set_op_params_i32(result, 1, s1);
4606
+ ggml_set_op_params_i32(result, 2, s2);
4607
+ ggml_set_op_params_i32(result, 3, p0);
4608
+ ggml_set_op_params_i32(result, 4, p1);
4609
+ ggml_set_op_params_i32(result, 5, p2);
4610
+ ggml_set_op_params_i32(result, 6, d0);
4611
+ ggml_set_op_params_i32(result, 7, d1);
4612
+ ggml_set_op_params_i32(result, 8, d2);
4613
+ ggml_set_op_params_i32(result, 9, c);
4614
+ ggml_set_op_params_i32(result, 10, n);
4615
+ ggml_set_op_params_i32(result, 11, oc);
4616
+
4617
+ result->op = GGML_OP_CONV_3D;
4618
+ result->src[0] = a;
4619
+ result->src[1] = b;
4620
+
4621
+ return result;
4622
+ }
4623
+
4111
4624
  // ggml_conv_transpose_2d_p0
4112
4625
 
4113
4626
  static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4224,24 +4737,21 @@ struct ggml_tensor * ggml_pool_2d_back(
4224
4737
  return result;
4225
4738
  }
4226
4739
 
4227
- // ggml_upscale
4740
+ // ggml_upscale / ggml_interpolate
4228
4741
 
4229
- static struct ggml_tensor * ggml_upscale_impl(
4742
+ static struct ggml_tensor * ggml_interpolate_impl(
4230
4743
  struct ggml_context * ctx,
4231
4744
  struct ggml_tensor * a,
4232
- int ne0,
4233
- int ne1,
4234
- int ne2,
4235
- int ne3,
4236
- enum ggml_scale_mode mode) {
4237
- GGML_ASSERT(a->ne[0] <= ne0);
4238
- GGML_ASSERT(a->ne[1] <= ne1);
4239
- GGML_ASSERT(a->ne[2] <= ne2);
4240
- GGML_ASSERT(a->ne[3] <= ne3);
4745
+ int64_t ne0,
4746
+ int64_t ne1,
4747
+ int64_t ne2,
4748
+ int64_t ne3,
4749
+ uint32_t mode) {
4750
+ GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
4241
4751
 
4242
4752
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
4243
4753
 
4244
- ggml_set_op_params_i32(result, 0, mode);
4754
+ ggml_set_op_params_i32(result, 0, (int32_t)mode);
4245
4755
 
4246
4756
  result->op = GGML_OP_UPSCALE;
4247
4757
  result->src[0] = a;
@@ -4254,7 +4764,8 @@ struct ggml_tensor * ggml_upscale(
4254
4764
  struct ggml_tensor * a,
4255
4765
  int scale_factor,
4256
4766
  enum ggml_scale_mode mode) {
4257
- return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4767
+ GGML_ASSERT(scale_factor > 1);
4768
+ return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4258
4769
  }
4259
4770
 
4260
4771
  struct ggml_tensor * ggml_upscale_ext(
@@ -4265,7 +4776,18 @@ struct ggml_tensor * ggml_upscale_ext(
4265
4776
  int ne2,
4266
4777
  int ne3,
4267
4778
  enum ggml_scale_mode mode) {
4268
- return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4779
+ return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4780
+ }
4781
+
4782
+ struct ggml_tensor * ggml_interpolate(
4783
+ struct ggml_context * ctx,
4784
+ struct ggml_tensor * a,
4785
+ int64_t ne0,
4786
+ int64_t ne1,
4787
+ int64_t ne2,
4788
+ int64_t ne3,
4789
+ uint32_t mode) {
4790
+ return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4269
4791
  }
4270
4792
 
4271
4793
  // ggml_pad
@@ -4277,11 +4799,36 @@ struct ggml_tensor * ggml_pad(
4277
4799
  int p1,
4278
4800
  int p2,
4279
4801
  int p3) {
4802
+ return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
4803
+ }
4804
+
4805
+ struct ggml_tensor * ggml_pad_ext(
4806
+ struct ggml_context * ctx,
4807
+ struct ggml_tensor * a,
4808
+ int lp0,
4809
+ int rp0,
4810
+ int lp1,
4811
+ int rp1,
4812
+ int lp2,
4813
+ int rp2,
4814
+ int lp3,
4815
+ int rp3
4816
+ ) {
4280
4817
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
4281
- a->ne[0] + p0,
4282
- a->ne[1] + p1,
4283
- a->ne[2] + p2,
4284
- a->ne[3] + p3);
4818
+ a->ne[0] + lp0 + rp0,
4819
+ a->ne[1] + lp1 + rp1,
4820
+ a->ne[2] + lp2 + rp2,
4821
+ a->ne[3] + lp3 + rp3);
4822
+
4823
+ ggml_set_op_params_i32(result, 0, lp0);
4824
+ ggml_set_op_params_i32(result, 1, rp0);
4825
+ ggml_set_op_params_i32(result, 2, lp1);
4826
+ ggml_set_op_params_i32(result, 3, rp1);
4827
+ ggml_set_op_params_i32(result, 4, lp2);
4828
+ ggml_set_op_params_i32(result, 5, rp2);
4829
+ ggml_set_op_params_i32(result, 6, lp3);
4830
+ ggml_set_op_params_i32(result, 7, rp3);
4831
+
4285
4832
 
4286
4833
  result->op = GGML_OP_PAD;
4287
4834
  result->src[0] = a;
@@ -4320,6 +4867,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
4320
4867
  return result;
4321
4868
  }
4322
4869
 
4870
+ // ggml_roll
4871
+
4872
+ struct ggml_tensor * ggml_roll(
4873
+ struct ggml_context * ctx,
4874
+ struct ggml_tensor * a,
4875
+ int shift0,
4876
+ int shift1,
4877
+ int shift2,
4878
+ int shift3) {
4879
+ GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
4880
+ GGML_ASSERT(abs(shift0) < a->ne[0]);
4881
+ GGML_ASSERT(abs(shift1) < a->ne[1]);
4882
+ GGML_ASSERT(abs(shift2) < a->ne[2]);
4883
+ GGML_ASSERT(abs(shift3) < a->ne[3]);
4884
+
4885
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
4886
+
4887
+ ggml_set_op_params_i32(result, 0, shift0);
4888
+ ggml_set_op_params_i32(result, 1, shift1);
4889
+ ggml_set_op_params_i32(result, 2, shift2);
4890
+ ggml_set_op_params_i32(result, 3, shift3);
4891
+
4892
+ result->op = GGML_OP_ROLL;
4893
+ result->src[0] = a;
4894
+
4895
+ return result;
4896
+ }
4897
+
4323
4898
  // ggml_arange
4324
4899
 
4325
4900
  struct ggml_tensor * ggml_arange(
@@ -4349,12 +4924,8 @@ struct ggml_tensor * ggml_timestep_embedding(
4349
4924
  struct ggml_tensor * timesteps,
4350
4925
  int dim,
4351
4926
  int max_period) {
4352
- int actual_dim = dim;
4353
- if (dim % 2 != 0) {
4354
- actual_dim = dim + 1;
4355
- }
4356
4927
 
4357
- struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
4928
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
4358
4929
 
4359
4930
  ggml_set_op_params_i32(result, 0, dim);
4360
4931
  ggml_set_op_params_i32(result, 1, max_period);
@@ -4414,13 +4985,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
4414
4985
  GGML_ASSERT(ggml_can_mul_mat(k, q));
4415
4986
  // TODO: check if vT can be multiplied by (k*qT)
4416
4987
 
4988
+ GGML_ASSERT(q->ne[3] == k->ne[3]);
4989
+ GGML_ASSERT(q->ne[3] == v->ne[3]);
4990
+
4417
4991
  if (mask) {
4418
4992
  GGML_ASSERT(ggml_is_contiguous(mask));
4419
- GGML_ASSERT(mask->ne[2] == 1);
4420
- GGML_ASSERT(mask->ne[3] == 1);
4421
4993
  GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
4422
4994
  "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
4423
4995
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
4996
+
4997
+ GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
4998
+ GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
4424
4999
  }
4425
5000
 
4426
5001
  if (max_bias > 0.0f) {
@@ -4462,6 +5037,22 @@ enum ggml_prec ggml_flash_attn_ext_get_prec(
4462
5037
  return (enum ggml_prec) prec_i32;
4463
5038
  }
4464
5039
 
5040
+ void ggml_flash_attn_ext_add_sinks(
5041
+ struct ggml_tensor * a,
5042
+ struct ggml_tensor * sinks) {
5043
+ if (!sinks) {
5044
+ a->src[4] = NULL;
5045
+ return;
5046
+ }
5047
+
5048
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
5049
+ GGML_ASSERT(a->src[4] == NULL);
5050
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
5051
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
5052
+
5053
+ a->src[4] = sinks;
5054
+ }
5055
+
4465
5056
  // ggml_flash_attn_back
4466
5057
 
4467
5058
  struct ggml_tensor * ggml_flash_attn_back(
@@ -4548,7 +5139,6 @@ struct ggml_tensor * ggml_ssm_conv(
4548
5139
  const int64_t n_s = sx->ne[2];
4549
5140
 
4550
5141
  // TODO: maybe support other strides than 1?
4551
- // FIXME: this is always true?
4552
5142
  GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
4553
5143
  GGML_ASSERT(sx->ne[1] == d_inner);
4554
5144
  GGML_ASSERT(n_t >= 0);
@@ -4571,36 +5161,49 @@ struct ggml_tensor * ggml_ssm_scan(
4571
5161
  struct ggml_tensor * dt,
4572
5162
  struct ggml_tensor * A,
4573
5163
  struct ggml_tensor * B,
4574
- struct ggml_tensor * C) {
5164
+ struct ggml_tensor * C,
5165
+ struct ggml_tensor * ids) {
4575
5166
  GGML_ASSERT(ggml_is_contiguous(s));
4576
- GGML_ASSERT(ggml_is_contiguous(x));
4577
5167
  GGML_ASSERT(ggml_is_contiguous(dt));
4578
5168
  GGML_ASSERT(ggml_is_contiguous(A));
4579
- GGML_ASSERT(ggml_is_matrix(A));
4580
- GGML_ASSERT(ggml_is_3d(B));
4581
- GGML_ASSERT(ggml_is_3d(s));
5169
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
4582
5170
  GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
4583
5171
  GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
4584
- GGML_ASSERT(ggml_are_same_shape(x, dt));
5172
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
5173
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
5174
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
4585
5175
  GGML_ASSERT(ggml_are_same_shape(B, C));
5176
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
4586
5177
 
4587
5178
  {
4588
5179
  const int64_t d_state = s->ne[0];
4589
- const int64_t d_inner = s->ne[1];
4590
- const int64_t n_seq_tokens = x->ne[1];
4591
- const int64_t n_seqs = x->ne[2];
4592
-
4593
- GGML_ASSERT(s->ne[2] == n_seqs);
4594
- GGML_ASSERT(x->ne[0] == d_inner);
4595
- GGML_ASSERT(A->ne[0] == d_state);
4596
- GGML_ASSERT(A->ne[1] == d_inner);
5180
+ const int64_t head_dim = x->ne[0];
5181
+ const int64_t n_head = x->ne[1];
5182
+ const int64_t n_seq_tokens = x->ne[2];
5183
+ const int64_t n_seqs = x->ne[3];
5184
+
5185
+ GGML_ASSERT(dt->ne[0] == n_head);
5186
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
5187
+ GGML_ASSERT(dt->ne[2] == n_seqs);
5188
+ GGML_ASSERT(ggml_is_3d(dt));
5189
+ GGML_ASSERT(s->ne[1] == head_dim);
5190
+ GGML_ASSERT(s->ne[2] == n_head);
4597
5191
  GGML_ASSERT(B->ne[0] == d_state);
4598
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
4599
- GGML_ASSERT(B->ne[2] == n_seqs);
5192
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
5193
+ GGML_ASSERT(B->ne[3] == n_seqs);
5194
+ GGML_ASSERT(ids->ne[0] == n_seqs);
5195
+ GGML_ASSERT(ggml_is_vector(ids));
5196
+ GGML_ASSERT(A->ne[1] == n_head);
5197
+ GGML_ASSERT(ggml_is_matrix(A));
5198
+
5199
+ if (A->ne[0] != 1) {
5200
+ // Mamba-1 has more granular decay factors
5201
+ GGML_ASSERT(A->ne[0] == d_state);
5202
+ }
4600
5203
  }
4601
5204
 
4602
5205
  // concatenated y + ssm_states
4603
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
5206
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
4604
5207
 
4605
5208
  result->op = GGML_OP_SSM_SCAN;
4606
5209
  result->src[0] = s;
@@ -4609,6 +5212,7 @@ struct ggml_tensor * ggml_ssm_scan(
4609
5212
  result->src[3] = A;
4610
5213
  result->src[4] = B;
4611
5214
  result->src[5] = C;
5215
+ result->src[6] = ids;
4612
5216
 
4613
5217
  return result;
4614
5218
  }
@@ -5164,6 +5768,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
5164
5768
  return result;
5165
5769
  }
5166
5770
 
5771
+ // opt_step_sgd
5772
+
5773
+ struct ggml_tensor * ggml_opt_step_sgd(
5774
+ struct ggml_context * ctx,
5775
+ struct ggml_tensor * a,
5776
+ struct ggml_tensor * grad,
5777
+ struct ggml_tensor * params) {
5778
+ GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
5779
+ GGML_ASSERT(ggml_are_same_shape(a, grad));
5780
+ GGML_ASSERT(params->type == GGML_TYPE_F32);
5781
+ GGML_ASSERT(ggml_nelements(params) == 2);
5782
+
5783
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5784
+
5785
+ result->op = GGML_OP_OPT_STEP_SGD;
5786
+ result->src[0] = a;
5787
+ result->src[1] = grad;
5788
+ result->src[2] = params;
5789
+
5790
+ return result;
5791
+ }
5792
+
5167
5793
  ////////////////////////////////////////////////////////////////////////////////
5168
5794
 
5169
5795
  struct ggml_hash_set ggml_hash_set_new(size_t size) {
@@ -5432,7 +6058,7 @@ static void ggml_compute_backward(
5432
6058
  } break;
5433
6059
  case GGML_OP_MEAN: {
5434
6060
  if (src0_needs_grads) {
5435
- ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
6061
+ ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
5436
6062
  }
5437
6063
  } break;
5438
6064
  case GGML_OP_REPEAT: {
@@ -5509,7 +6135,7 @@ static void ggml_compute_backward(
5509
6135
  if (src0_needs_grads) {
5510
6136
  float s;
5511
6137
  memcpy(&s, tensor->op_params, sizeof(float));
5512
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
6138
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
5513
6139
  }
5514
6140
  } break;
5515
6141
  case GGML_OP_SET: {
@@ -5749,13 +6375,28 @@ static void ggml_compute_backward(
5749
6375
  }
5750
6376
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5751
6377
  } break;
6378
+ case GGML_OP_GLU: {
6379
+ switch (ggml_get_glu_op(tensor)) {
6380
+ case GGML_GLU_OP_SWIGLU: {
6381
+ if (src0_needs_grads) {
6382
+ GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6383
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6384
+ }
6385
+ if (src1_needs_grads) {
6386
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6387
+ }
6388
+ } break;
6389
+ default: {
6390
+ GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6391
+ } //break;
6392
+ }
6393
+ } break;
5752
6394
  case GGML_OP_NONE: {
5753
6395
  // noop
5754
6396
  } break;
5755
6397
  case GGML_OP_COUNT:
5756
6398
  default: {
5757
- fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
5758
- GGML_ABORT("fatal error");
6399
+ GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
5759
6400
  } //break;
5760
6401
  }
5761
6402
 
@@ -5764,19 +6405,32 @@ static void ggml_compute_backward(
5764
6405
  GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
5765
6406
  }
5766
6407
 
5767
- static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
6408
+ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
5768
6409
  // check if already visited
5769
- if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
5770
- return;
6410
+ size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
6411
+ GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
6412
+ if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
6413
+ // This is the first time we see this node in the current graph.
6414
+ cgraph->visited_hash_set.keys[node_hash_pos] = node;
6415
+ ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
6416
+ cgraph->use_counts[node_hash_pos] = 0;
6417
+ } else {
6418
+ // already visited
6419
+ return node_hash_pos;
5771
6420
  }
5772
6421
 
5773
6422
  for (int i = 0; i < GGML_MAX_SRC; ++i) {
5774
6423
  const int k =
5775
6424
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
5776
6425
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
5777
- /* unknown order, just fall back to using i*/ i;
5778
- if (node->src[k]) {
5779
- ggml_visit_parents(cgraph, node->src[k]);
6426
+ /* unknown order, just fall back to using i */ i;
6427
+
6428
+ struct ggml_tensor * src = node->src[k];
6429
+ if (src) {
6430
+ size_t src_hash_pos = ggml_visit_parents(cgraph, src);
6431
+
6432
+ // Update the use count for this operand.
6433
+ cgraph->use_counts[src_hash_pos]++;
5780
6434
  }
5781
6435
  }
5782
6436
 
@@ -5800,6 +6454,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
5800
6454
  cgraph->nodes[cgraph->n_nodes] = node;
5801
6455
  cgraph->n_nodes++;
5802
6456
  }
6457
+
6458
+ return node_hash_pos;
5803
6459
  }
5804
6460
 
5805
6461
  static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -5937,6 +6593,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
5937
6593
  incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
5938
6594
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
5939
6595
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6596
+ incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
5940
6597
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
5941
6598
  if (grads) {
5942
6599
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -5966,11 +6623,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
5966
6623
 
5967
6624
  void * p = cgraph + 1;
5968
6625
 
5969
- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5970
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5971
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5972
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5973
- struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6626
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6627
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6628
+ int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6629
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6630
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6631
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5974
6632
 
5975
6633
  ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
5976
6634
 
@@ -5985,6 +6643,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
5985
6643
  /*.grads =*/ grads_ptr,
5986
6644
  /*.grad_accs =*/ grad_accs_ptr,
5987
6645
  /*.leafs =*/ leafs_ptr,
6646
+ /*.use_counts =*/ use_counts_ptr,
5988
6647
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
5989
6648
  /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
5990
6649
  };
@@ -6011,7 +6670,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
6011
6670
  /*.grads =*/ NULL, // gradients would need visited_hash_set
6012
6671
  /*.grad_accs =*/ NULL,
6013
6672
  /*.leafs =*/ NULL,
6014
- /*.visited_hash_set =*/ { 0, NULL, NULL },
6673
+ /*.use_counts =*/ cgraph0->use_counts,
6674
+ /*.visited_hash_set =*/ cgraph0->visited_hash_set,
6015
6675
  /*.order =*/ cgraph0->order,
6016
6676
  };
6017
6677
 
@@ -6038,7 +6698,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
6038
6698
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
6039
6699
  // copy all hashset keys (tensors) that are in use
6040
6700
  if (ggml_bitset_get(src->visited_hash_set.used, i)) {
6041
- ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6701
+ size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6702
+ dst->use_counts[new_hash_pos] = src->use_counts[i];
6042
6703
  }
6043
6704
  }
6044
6705
 
@@ -6242,20 +6903,18 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr
6242
6903
  static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
6243
6904
  struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
6244
6905
  struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
6245
- fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
6906
+ fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
6246
6907
  gparent0 ? (void *) gparent0 : (void *) parent,
6247
- gparent0 ? "g" : "x",
6248
6908
  gparent ? (void *) gparent : (void *) node,
6249
- gparent ? "g" : "x",
6250
6909
  gparent ? "empty" : "vee",
6251
6910
  gparent ? "dashed" : "solid",
6252
6911
  label);
6253
6912
  }
6254
6913
 
6255
6914
  static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
6256
- fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
6257
- (void *) parent, "x",
6258
- (void *) node, "x",
6915
+ fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
6916
+ (void *) parent,
6917
+ (void *) node,
6259
6918
  label);
6260
6919
  }
6261
6920
 
@@ -6476,6 +7135,7 @@ size_t ggml_quantize_chunk(
6476
7135
  case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6477
7136
  case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6478
7137
  case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7138
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6479
7139
  case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6480
7140
  case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
6481
7141
  case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;