whispercpp 1.3.1 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (857)
  1. checksums.yaml +4 -4
  2. data/.gitignore +7 -3
  3. data/README.md +161 -43
  4. data/Rakefile +45 -13
  5. data/ext/.gitignore +4 -8
  6. data/ext/dependencies.rb +73 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +85 -0
  9. data/ext/ruby_whisper.c +177 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +672 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1303 -0
  15. data/ext/ruby_whisper_segment.c +220 -0
  16. data/ext/ruby_whisper_transcribe.cpp +93 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  19. data/ext/sources/CMakeLists.txt +255 -0
  20. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  21. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  22. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  23. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  24. data/ext/sources/bindings/javascript/package.json +26 -0
  25. data/ext/sources/bindings/javascript/whisper.js +19 -0
  26. data/ext/sources/build-xcframework.sh +547 -0
  27. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  28. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  29. data/ext/sources/cmake/build-info.cmake +60 -0
  30. data/ext/sources/cmake/git-vars.cmake +22 -0
  31. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  32. data/ext/sources/cmake/whisper.pc.in +10 -0
  33. data/ext/sources/examples/CMakeLists.txt +124 -0
  34. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  35. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
  36. data/ext/sources/examples/addon.node/addon.cpp +557 -0
  37. data/ext/sources/examples/addon.node/index.js +57 -0
  38. data/ext/sources/examples/addon.node/package.json +16 -0
  39. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  40. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  41. data/ext/sources/examples/bench/bench.cpp +176 -0
  42. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  43. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  44. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  45. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  46. data/ext/sources/examples/cli/cli.cpp +1295 -0
  47. data/ext/sources/examples/coi-serviceworker.js +146 -0
  48. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  49. data/ext/sources/examples/command/command.cpp +800 -0
  50. data/ext/sources/examples/command/commands.txt +9 -0
  51. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  52. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  53. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  54. data/ext/sources/examples/common-ggml.cpp +238 -0
  55. data/ext/sources/examples/common-ggml.h +18 -0
  56. data/ext/sources/examples/common-sdl.cpp +227 -0
  57. data/ext/sources/examples/common-sdl.h +49 -0
  58. data/ext/sources/examples/common-whisper.cpp +175 -0
  59. data/ext/sources/examples/common-whisper.h +24 -0
  60. data/ext/sources/examples/common.cpp +675 -0
  61. data/ext/sources/examples/common.h +322 -0
  62. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  63. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  64. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  65. data/ext/sources/examples/generate-karaoke.sh +57 -0
  66. data/ext/sources/examples/grammar-parser.cpp +423 -0
  67. data/ext/sources/examples/grammar-parser.h +29 -0
  68. data/ext/sources/examples/helpers.js +191 -0
  69. data/ext/sources/examples/json.hpp +24596 -0
  70. data/ext/sources/examples/livestream.sh +112 -0
  71. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  72. data/ext/sources/examples/lsp/lsp.cpp +469 -0
  73. data/ext/sources/examples/lsp/whisper.vim +362 -0
  74. data/ext/sources/examples/miniaudio.h +93468 -0
  75. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  76. data/ext/sources/examples/python/whisper_processor.py +54 -0
  77. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  78. data/ext/sources/examples/quantize/quantize.cpp +226 -0
  79. data/ext/sources/examples/server/CMakeLists.txt +15 -0
  80. data/ext/sources/examples/server/bench.js +29 -0
  81. data/ext/sources/examples/server/httplib.h +10497 -0
  82. data/ext/sources/examples/server/server.cpp +1238 -0
  83. data/ext/sources/examples/server.py +115 -0
  84. data/ext/sources/examples/stb_vorbis.c +5584 -0
  85. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  86. data/ext/sources/examples/stream/stream.cpp +435 -0
  87. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  88. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  89. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  90. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  91. data/ext/sources/examples/sycl/build.sh +22 -0
  92. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  93. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  94. data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
  95. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  96. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  97. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  98. data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
  99. data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
  100. data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
  101. data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
  102. data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
  103. data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
  104. data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
  105. data/ext/sources/examples/talk-llama/llama-context.h +297 -0
  106. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  107. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  108. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  109. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  110. data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
  111. data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
  112. data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
  113. data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
  114. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  115. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  116. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  117. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  118. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  119. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  120. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
  124. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  125. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  126. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  127. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  128. data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
  129. data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
  130. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  131. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  132. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
  133. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  134. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
  135. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  136. data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
  137. data/ext/sources/examples/talk-llama/llama-model.h +452 -0
  138. data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
  139. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  140. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  141. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  142. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
  143. data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
  144. data/ext/sources/examples/talk-llama/llama.cpp +358 -0
  145. data/ext/sources/examples/talk-llama/llama.h +1484 -0
  146. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  147. data/ext/sources/examples/talk-llama/speak +40 -0
  148. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  149. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  150. data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
  151. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  152. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  153. data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
  154. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  155. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  156. data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
  157. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  158. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  159. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  160. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  161. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  162. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  163. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  164. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  165. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
  166. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  167. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  168. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  169. data/ext/sources/ggml/CMakeLists.txt +435 -0
  170. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  171. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  172. data/ext/sources/ggml/cmake/common.cmake +50 -0
  173. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  174. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
  176. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
  178. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  179. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  180. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  181. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  182. data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
  183. data/ext/sources/ggml/include/gguf.h +202 -0
  184. data/ext/sources/ggml/src/CMakeLists.txt +404 -0
  185. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  186. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  187. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  188. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
  189. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  191. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
  192. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  195. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  196. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  197. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
  198. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
  199. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  200. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  201. data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
  202. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
  203. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
  205. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  208. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  209. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  210. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  211. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  212. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  213. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  214. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  215. data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
  216. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  217. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
  218. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  219. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  220. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  221. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  222. data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
  223. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
  224. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
  225. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
  226. data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
  227. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  228. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  229. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  230. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  231. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
  232. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
  233. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
  234. data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
  235. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  236. data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
  237. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  238. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  239. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
  240. data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
  241. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  242. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  243. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
  244. data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
  245. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  246. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  247. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  248. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  249. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  250. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  251. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  252. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  253. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  254. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  255. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  256. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  257. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
  259. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  260. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  262. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  264. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  266. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  267. data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
  269. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  270. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  271. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  273. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  275. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  276. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  277. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  278. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  287. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  289. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  290. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  291. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  292. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  293. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
  294. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  295. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  296. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  297. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  298. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  299. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  300. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  301. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  302. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
  304. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
  305. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  307. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  308. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  309. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  310. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  312. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  313. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  314. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  315. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  316. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  317. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  318. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  319. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  320. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  321. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  322. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  323. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  324. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  325. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  326. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  327. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
  328. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  329. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  330. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
  332. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  430. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  432. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  433. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  434. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  436. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  437. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  438. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  439. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  440. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  441. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  442. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  443. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  444. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  445. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  446. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  447. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  448. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  449. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  450. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  451. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  452. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  453. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  454. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  455. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  456. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  457. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  458. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  459. data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
  460. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
  461. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  462. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  463. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  464. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  465. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  466. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  467. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  468. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  469. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
  470. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
  471. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  481. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  482. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  483. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  484. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  485. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  486. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  487. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  488. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  489. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  490. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  491. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  492. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  493. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  494. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  495. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  496. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  497. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  498. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  499. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  500. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  501. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  502. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  503. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  504. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  505. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  506. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  507. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  508. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  509. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
  510. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
  511. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
  512. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
  513. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  514. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  515. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  516. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
  517. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  526. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  527. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  528. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  529. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  530. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  531. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  532. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  533. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  534. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  535. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  536. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  537. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  538. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  539. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  540. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  541. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  542. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  543. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  544. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  545. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  546. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  547. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  548. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  549. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  550. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  551. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  552. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  553. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  554. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  555. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  556. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  557. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  558. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  559. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  560. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  561. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  562. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  563. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  564. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  565. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  566. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  567. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  568. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  569. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
  570. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  571. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
  572. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  573. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  574. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
  575. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  576. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  577. data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
  578. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
  579. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  580. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
  581. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  582. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
  583. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  584. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
  585. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  586. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
  587. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
  588. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  589. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
  590. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
  591. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
  592. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
  593. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
  594. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  595. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
  596. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  597. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  598. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  599. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  600. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
  601. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  602. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
  603. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  604. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
  605. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  606. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  607. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  608. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  609. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
  610. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
  611. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  612. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
  613. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  614. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
  615. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
  616. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
  617. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  618. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
  619. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
  620. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  623. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  740. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  743. data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
  744. data/ext/sources/ggml/src/ggml.cpp +26 -0
  745. data/ext/sources/ggml/src/gguf.cpp +1351 -0
  746. data/ext/{include → sources/include}/whisper.h +70 -2
  747. data/ext/sources/src/CMakeLists.txt +145 -0
  748. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  749. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  750. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  751. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
  752. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  753. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
  754. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  755. data/ext/sources/src/whisper-arch.h +197 -0
  756. data/ext/{src → sources/src}/whisper.cpp +1966 -386
  757. data/ext/sources/tests/CMakeLists.txt +105 -0
  758. data/ext/sources/tests/earnings21/eval.mk +58 -0
  759. data/ext/sources/tests/earnings21/eval.py +68 -0
  760. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  761. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  762. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  763. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  764. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  765. data/ext/sources/tests/en-0-ref.txt +1 -0
  766. data/ext/sources/tests/en-1-ref.txt +1 -0
  767. data/ext/sources/tests/en-2-ref.txt +1 -0
  768. data/ext/sources/tests/es-0-ref.txt +1 -0
  769. data/ext/sources/tests/librispeech/eval.mk +39 -0
  770. data/ext/sources/tests/librispeech/eval.py +47 -0
  771. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  772. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  773. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  774. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  775. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  776. data/ext/sources/tests/run-tests.sh +130 -0
  777. data/ext/sources/tests/test-c.c +3 -0
  778. data/ext/sources/tests/test-vad-full.cpp +54 -0
  779. data/ext/sources/tests/test-vad.cpp +83 -0
  780. data/ext/sources/tests/test-whisper.js +58 -0
  781. data/extsources.rb +39 -5
  782. data/lib/whisper/context.rb +15 -0
  783. data/lib/whisper/model/uri.rb +202 -126
  784. data/lib/whisper/segment.rb +58 -0
  785. data/sig/whisper.rbs +510 -0
  786. data/test/helper.rb +24 -0
  787. data/{tests → test}/test_callback.rb +45 -3
  788. data/{tests → test}/test_error.rb +2 -2
  789. data/{tests → test}/test_model.rb +47 -0
  790. data/test/test_package.rb +51 -0
  791. data/test/test_params.rb +297 -0
  792. data/test/test_segment.rb +146 -0
  793. data/test/test_vad.rb +19 -0
  794. data/test/test_vad_params.rb +103 -0
  795. data/{tests → test}/test_whisper.rb +106 -36
  796. data/whispercpp.gemspec +5 -5
  797. metadata +837 -134
  798. data/ext/cpu.mk +0 -9
  799. data/ext/examples/dr_wav.h +0 -8815
  800. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  801. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  802. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  803. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
  804. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  805. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  806. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  807. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  808. data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
  809. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  810. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  811. data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
  812. data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
  813. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  814. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  815. data/ext/metal-embed.mk +0 -17
  816. data/ext/metal.mk +0 -6
  817. data/ext/ruby_whisper.cpp +0 -1909
  818. data/ext/scripts/get-flags.mk +0 -38
  819. data/lib/whisper.rb +0 -2
  820. data/tests/helper.rb +0 -7
  821. data/tests/test_package.rb +0 -31
  822. data/tests/test_params.rb +0 -160
  823. data/tests/test_segment.rb +0 -83
  824. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  825. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  826. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  827. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  828. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  829. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  830. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  831. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  832. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  833. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  834. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  835. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  836. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  837. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  838. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  839. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  840. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  841. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  842. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  843. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  844. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  845. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  846. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
  847. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
  848. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  849. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  850. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  851. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  852. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  853. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  854. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
  855. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  856. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  857. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
data/ext/{ggml → sources/ggml}/src/ggml.c

@@ -4,6 +4,7 @@
  #include "ggml-backend.h"
  #include "ggml-impl.h"
  #include "ggml-threading.h"
+ #include "ggml-cpu.h"
  #include "ggml.h"

  // FIXME: required here for quantization functions
@@ -60,15 +61,17 @@
  #define m512i(p) (__m512i)(p)
  #endif

- // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
- float ggml_table_f32_f16[1 << 16];
+ #if defined(__linux__) || \
+     defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
+     (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)

- #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
-     (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
  #include <unistd.h>
  #include <sys/types.h>
  #include <sys/stat.h>
  #include <sys/wait.h>
+ #if defined(__linux__)
+ #include <sys/prctl.h>
+ #endif

  #if defined(__ANDROID__)
  #include <unwind.h>
@@ -127,11 +130,46 @@ static void ggml_print_backtrace_symbols(void) {
  }
  #endif

- static void ggml_print_backtrace(void) {
-     char attach[32];
-     snprintf(attach, sizeof(attach), "attach %d", getpid());
-     int pid = fork();
-     if (pid == 0) {
+ void ggml_print_backtrace(void) {
+     const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
+     if (GGML_NO_BACKTRACE) {
+         return;
+     }
+ #if defined(__linux__)
+     FILE * f = fopen("/proc/self/status", "r");
+     size_t size = 0;
+     char * line = NULL;
+     ssize_t length = 0;
+     while ((length = getline(&line, &size, f)) > 0) {
+         if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
+             (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
+             // Already being debugged, and the breakpoint is the later abort()
+             free(line);
+             fclose(f);
+             return;
+         }
+     }
+     free(line);
+     fclose(f);
+     int lock[2] = { -1, -1 };
+     (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
+ #endif
+     const int parent_pid = getpid();
+     const int child_pid = fork();
+     if (child_pid < 0) { // error
+ #if defined(__linux__)
+         close(lock[1]);
+         close(lock[0]);
+ #endif
+         return;
+     } else if (child_pid == 0) { // child
+         char attach[32];
+         snprintf(attach, sizeof(attach), "attach %d", parent_pid);
+ #if defined(__linux__)
+         close(lock[1]);
+         (void) !read(lock[0], lock, 1);
+         close(lock[0]);
+ #endif
          // try gdb
          execlp("gdb", "gdb", "--batch",
              "-ex", "set style enabled on",
@@ -144,22 +182,22 @@ static void ggml_print_backtrace(void) {
          execlp("lldb", "lldb", "--batch",
              "-o", "bt",
              "-o", "quit",
-             "-p", attach,
+             "-p", &attach[sizeof("attach ") - 1],
              (char *) NULL);
-         exit(EXIT_FAILURE);
-     } else {
-         int wstatus;
-         waitpid(pid, &wstatus, 0);
-         if (WIFEXITED(wstatus)) {
-             if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
-                 // gdb failed, fallback to backtrace_symbols
-                 ggml_print_backtrace_symbols();
-             }
-         }
+         // gdb failed, fallback to backtrace_symbols
+         ggml_print_backtrace_symbols();
+         _Exit(0);
+     } else { // parent
+ #if defined(__linux__)
+         prctl(PR_SET_PTRACER, child_pid);
+         close(lock[1]);
+         close(lock[0]);
+ #endif
+         waitpid(child_pid, NULL, 0);
      }
  }
  #else
- static void ggml_print_backtrace(void) {
+ void ggml_print_backtrace(void) {
      // platform not supported
  }
  #endif
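
The rewritten handler can now be disabled at runtime. A minimal sketch (not part of the diff), assuming only the GGML_NO_BACKTRACE check shown above and POSIX setenv:

    #include <stdlib.h>
    #include "ggml.h"

    int main(void) {
        // suppress the gdb/lldb backtrace that ggml_abort() would otherwise spawn
        setenv("GGML_NO_BACKTRACE", "1", 1);
        GGML_ASSERT(2 + 2 == 4); // a failing assert would call ggml_abort()
        return 0;
    }
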
@@ -180,6 +218,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
      abort();
  }

+ // ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
+
  //
  // logging
  //
@@ -236,7 +276,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {


  void * ggml_aligned_malloc(size_t size) {
+ #if defined(__s390x__)
+     const int alignment = 256;
+ #else
      const int alignment = 64;
+ #endif

  #if defined(_MSC_VER) || defined(__MINGW32__)
      return _aligned_malloc(size, alignment);
@@ -374,58 +418,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
      }
  }

- // FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
- // currently, the ggml_cpu_has_* functions are entirely compile-time
  void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-     int64_t i = 0;
- #if defined(__F16C__)
-     //if (ggml_cpu_has_f16c()) {
-         for (; i + 7 < n; i += 8) {
-             __m256 x_vec = _mm256_loadu_ps(x + i);
-             __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-             _mm_storeu_si128((__m128i *)(y + i), y_vec);
-         }
-         for(; i + 3 < n; i += 4) {
-             __m128 x_vec = _mm_loadu_ps(x + i);
-             __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-             _mm_storel_epi64((__m128i *)(y + i), y_vec);
-         }
-     //}
- #endif
-     for (; i < n; i++) {
+     int i = 0;
+     for (; i < n; ++i) {
          y[i] = GGML_FP32_TO_FP16(x[i]);
      }
  }

  void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-     int64_t i = 0;
- #if defined(__AVX512F__)
-     //if (ggml_cpu_has_avx512()) {
-         for (; i + 16 <= n; i += 16) {
-             _mm512_storeu_ps(y + i,
-                 _mm512_castsi512_ps(
-                     _mm512_slli_epi32(
-                         _mm512_cvtepu16_epi32(
-                             _mm256_loadu_si256(
-                                 (const __m256i *)(x + i))),
-                         16)));
-         }
-     //}
- #endif
- #if defined(__AVX2__)
-     //if (ggml_cpu_has_avx2()) {
-         for (; i + 8 <= n; i += 8) {
-             _mm256_storeu_ps(y + i,
-                 _mm256_castsi256_ps(
-                     _mm256_slli_epi32(
-                         _mm256_cvtepu16_epi32(
-                             _mm_loadu_si128(
-                                 (const __m128i *)(x + i))),
-                         16)));
-         }
-     //}
- #endif
-     for (; i < n; i++) {
+     int i = 0;
+     for (; i < n; ++i) {
          y[i] = GGML_BF16_TO_FP32(x[i]);
      }
  }
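
The SIMD fast paths of these row converters move into the ggml-cpu backend, leaving the scalar core above. A minimal round-trip sketch of the unchanged public API (not part of the diff), assuming only ggml.h; sizes are illustrative:

    #include "ggml.h"

    void fp16_roundtrip(void) {
        float       src[8] = { 0.0f, 0.5f, 1.0f, -2.0f, 3.25f, -4.5f, 8.0f, 1e-3f };
        ggml_fp16_t tmp[8];
        float       dst[8];

        ggml_fp32_to_fp16_row(src, tmp, 8); // f32 -> f16 (rounds to nearest)
        ggml_fp16_to_fp32_row(tmp, dst, 8); // f16 -> f32 (exact)
    }
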
@@ -557,9 +559,9 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
  #endif

  }
- static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
- static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
- static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+ static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+ static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+ static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);

  static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
      [GGML_TYPE_I8] = {
@@ -883,12 +885,6 @@ struct ggml_context {
      struct ggml_object * objects_end;
  };

- struct ggml_context_container {
-     bool used;
-
-     struct ggml_context context;
- };
-
  //
  // data types
  //
@@ -921,6 +917,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "RMS_NORM",
      "RMS_NORM_BACK",
      "GROUP_NORM",
+     "L2_NORM",

      "MUL_MAT",
      "MUL_MAT_ID",
@@ -936,6 +933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "TRANSPOSE",
      "GET_ROWS",
      "GET_ROWS_BACK",
+     "SET_ROWS",
      "DIAG",
      "DIAG_MASK_INF",
      "DIAG_MASK_ZERO",
@@ -947,6 +945,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "CONV_TRANSPOSE_1D",
      "IM2COL",
      "IM2COL_BACK",
+     "CONV_2D",
+     "CONV_2D_DW",
      "CONV_TRANSPOSE_2D",
      "POOL_1D",
      "POOL_2D",
@@ -954,6 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "UPSCALE",
      "PAD",
      "PAD_REFLECT_1D",
+     "ROLL",
      "ARANGE",
      "TIMESTEP_EMBEDDING",
      "ARGSORT",
@@ -968,26 +969,25 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "GET_REL_POS",
      "ADD_REL_POS",
      "RWKV_WKV6",
+     "GATED_LINEAR_ATTN",
+     "RWKV_WKV7",

      "UNARY",

-     "MAP_UNARY",
-     "MAP_BINARY",
-
-     "MAP_CUSTOM1_F32",
-     "MAP_CUSTOM2_F32",
-     "MAP_CUSTOM3_F32",
-
      "MAP_CUSTOM1",
      "MAP_CUSTOM2",
      "MAP_CUSTOM3",

+     "CUSTOM",
+
      "CROSS_ENTROPY_LOSS",
      "CROSS_ENTROPY_LOSS_BACK",
      "OPT_STEP_ADAMW",
+
+     "GLU",
  };

- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+ static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
@@ -1017,6 +1017,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "rms_norm(x)",
      "rms_norm_back(x)",
      "group_norm(x)",
+     "l2_norm(x)",

      "X*Y",
      "X[i]*Y",
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "transpose(x)",
      "get_rows(x)",
      "get_rows_back(x)",
+     "set_rows(x)",
      "diag(x)",
      "diag_mask_inf(x)",
      "diag_mask_zero(x)",
@@ -1043,6 +1045,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "conv_transpose_1d(x)",
      "im2col(x)",
      "im2col_back(x)",
+     "conv_2d(x)",
+     "conv_2d_dw(x)",
      "conv_transpose_2d(x)",
      "pool_1d(x)",
      "pool_2d(x)",
@@ -1050,6 +1054,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "upscale(x)",
      "pad(x)",
      "pad_reflect_1d(x)",
+     "roll(x)",
      "arange(start, stop, step)",
      "timestep_embedding(timesteps, dim, max_period)",
      "argsort(x)",
@@ -1064,26 +1069,25 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "get_rel_pos(x)",
      "add_rel_pos(x)",
      "rwkv_wkv6(k, v, r, tf, td, s)",
+     "gated_linear_attn(k, v, q, gate, s)",
+     "rwkv_wkv7(r, w, k, v, a, b, s)",

      "unary(x)",

-     "f(x)",
-     "f(x,y)",
-
-     "custom_f32(x)",
-     "custom_f32(x,y)",
-     "custom_f32(x,y,z)",
+     "map_custom(x)",
+     "map_custom(x,y)",
+     "map_custom(x,y,z)",

      "custom(x)",
-     "custom(x,y)",
-     "custom(x,y,z)",

      "cross_entropy_loss(x,y)",
      "cross_entropy_loss_back(x,y)",
      "adamw(x)",
+
+     "glu(x)",
  };

- static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+ static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -1103,9 +1107,19 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
      "HARDSWISH",
      "HARDSIGMOID",
      "EXP",
+     "GELU_ERF",
+ };
+
+ static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
+
+
+ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
+     "REGLU",
+     "GEGLU",
+     "SWIGLU",
  };

- static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
+ static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");


  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1145,6 +1159,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
  }

  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+     for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+         if (tensor->ne[i] <= 0) {
+             return 0;
+         }
+     }
+
      size_t nbytes;
      const size_t blck_size = ggml_blck_size(tensor->type);
      if (blck_size == 1) {
@@ -1204,11 +1224,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
      return GGML_UNARY_OP_NAME[op];
  }

+ const char * ggml_glu_op_name(enum ggml_glu_op op) {
+     return GGML_GLU_OP_NAME[op];
+ }
+
  const char * ggml_op_desc(const struct ggml_tensor * t) {
      if (t->op == GGML_OP_UNARY) {
          enum ggml_unary_op uop = ggml_get_unary_op(t);
          return ggml_unary_op_name(uop);
      }
+     if (t->op == GGML_OP_GLU) {
+         enum ggml_glu_op gop = ggml_get_glu_op(t);
+         return ggml_glu_op_name(gop);
+     }
      return ggml_op_name(t->op);
  }

@@ -1328,12 +1356,29 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
      return ggml_is_contiguous_n(tensor, 2);
  }

+ bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+     return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+ }
+
  bool ggml_is_permuted(const struct ggml_tensor * tensor) {
      static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

      return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
  }

+ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+     return
+         tensor->nb[0] > tensor->nb[2] &&
+         tensor->nb[1] > tensor->nb[0] &&
+         tensor->nb[2] == ggml_type_size(tensor->type);
+ }
+
+ bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+     return
+         tensor->ne[0] == ggml_blck_size(tensor->type) ||
+         tensor->nb[0] == ggml_type_size(tensor->type);
+ }
+
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
      static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

@@ -1373,7 +1418,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
      (t0->nb[3] == t1->nb[3]);
  }

- // check if t1 can be represented as a repeatition of t0
+ // check if t1 can be represented as a repetition of t0
  bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
      static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

@@ -1405,14 +1450,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
          // initialize time system (required on Windows)
          ggml_time_init();

-         for (int i = 0; i < (1 << 16); ++i) {
-             union {
-                 uint16_t u16;
-                 ggml_fp16_t fp16;
-             } u = {i};
-             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
-         }
-
          is_first_call = false;
      }

@@ -1588,15 +1625,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(

      struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);

- #ifdef __clang__
-     // temporary until ggml_tensor::backend is removed
- #pragma clang diagnostic push
- #pragma clang diagnostic ignored "-Wdeprecated-declarations"
- #endif
-
      *result = (struct ggml_tensor) {
          /*.type =*/ type,
-         /*.backend =*/ GGML_BACKEND_TYPE_CPU,
          /*.buffer =*/ NULL,
          /*.ne =*/ { 1, 1, 1, 1 },
          /*.nb =*/ { 0, 0, 0, 0 },
@@ -1612,10 +1642,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
          /*.padding =*/ { 0 },
      };

- #ifdef __clang__
- #pragma clang diagnostic pop
- #endif
-
      // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
      //GGML_ASSERT_ALIGNED(result->data);

@@ -1727,6 +1753,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
      return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
  }

+ enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
+     GGML_ASSERT(tensor->op == GGML_OP_GLU);
+     return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
+ }
+
  const char * ggml_get_name(const struct ggml_tensor * tensor) {
      return tensor->name;
  }
@@ -2309,6 +2340,26 @@ struct ggml_tensor * ggml_repeat(
      return result;
  }

+ struct ggml_tensor * ggml_repeat_4d(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+     const bool can_repeat = ggml_is_empty(a) || (
+         (ne0 % a->ne[0] == 0) &&
+         (ne1 % a->ne[1] == 0) &&
+         (ne2 % a->ne[2] == 0) &&
+         (ne3 % a->ne[3] == 0)
+     );
+     GGML_ASSERT(can_repeat);
+
+     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+
+     result->op = GGML_OP_REPEAT;
+     result->src[0] = a;
+
+     return result;
+ }
+
  // ggml_repeat_back

  struct ggml_tensor * ggml_repeat_back(
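
ggml_repeat_4d (added above) broadcasts a tensor to an explicit 4-D shape instead of copying the shape of a second tensor as ggml_repeat does. A minimal sketch (not part of the diff); sizes are illustrative, and each target dimension must be a multiple of the source dimension:

    #include "ggml.h"

    struct ggml_tensor * tile_example(struct ggml_context * ctx) {
        // a is [2, 3]; tile it 2x along dim 0 and 5x along dim 2 -> [4, 3, 5, 1]
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
        return ggml_repeat_4d(ctx, a, 4, 3, 5, 1);
    }
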
@@ -2333,6 +2384,7 @@ struct ggml_tensor * ggml_concat(
      struct ggml_tensor * b,
      int dim) {
      GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+     GGML_ASSERT(a->type == b->type);

      int64_t ne[GGML_MAX_DIMS];
      for (int d = 0; d < GGML_MAX_DIMS; ++d) {
@@ -2498,6 +2550,20 @@ struct ggml_tensor * ggml_gelu_inplace(
      return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
  }

+ // ggml_gelu_erf
+
+ struct ggml_tensor * ggml_gelu_erf(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
+ }
+
+ struct ggml_tensor * ggml_gelu_erf_inplace(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
+ }
+
  // ggml_gelu_quick

  struct ggml_tensor * ggml_gelu_quick(
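
GGML_UNARY_OP_GELU_ERF (added above) is the exact, erf-based GELU next to the existing tanh approximation. A minimal sketch of the two variants side by side (not part of the diff):

    #include "ggml.h"

    void gelu_variants(struct ggml_context * ctx, struct ggml_tensor * x) {
        struct ggml_tensor * exact  = ggml_gelu_erf(ctx, x); // 0.5*x*(1 + erf(x/sqrt(2)))
        struct ggml_tensor * approx = ggml_gelu(ctx, x);     // tanh-based approximation
        (void) exact; (void) approx;
    }
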
@@ -2571,6 +2637,114 @@ struct ggml_tensor * ggml_exp_inplace(
      return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
  }

+ // ggml_glu
+
+ static struct ggml_tensor * ggml_glu_impl(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         enum ggml_glu_op op,
+         bool swapped) {
+     GGML_ASSERT(ggml_is_contiguous_1(a));
+
+     if (b) {
+         GGML_ASSERT(ggml_is_contiguous_1(b));
+         GGML_ASSERT(ggml_are_same_shape(a, b));
+         GGML_ASSERT(a->type == b->type);
+     }
+
+     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
+     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
+
+     ggml_set_op_params_i32(result, 0, (int32_t) op);
+     ggml_set_op_params_i32(result, 1, (int32_t) swapped);
+
+     result->op = GGML_OP_GLU;
+     result->src[0] = a;
+     result->src[1] = b;
+
+     return result;
+ }
+
+ struct ggml_tensor * ggml_glu(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         enum ggml_glu_op op,
+         bool swapped) {
+     return ggml_glu_impl(ctx, a, NULL, op, swapped);
+ }
+
+ struct ggml_tensor * ggml_glu_split(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         enum ggml_glu_op op) {
+     return ggml_glu_impl(ctx, a, b, op, false);
+ }
+
+ // ggml_reglu
+
+ struct ggml_tensor * ggml_reglu(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
+ }
+
+ struct ggml_tensor * ggml_reglu_swapped(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
+ }
+
+ struct ggml_tensor * ggml_reglu_split(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b) {
+     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
+ }
+
+ // ggml_geglu
+
+ struct ggml_tensor * ggml_geglu(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
+ }
+
+ struct ggml_tensor * ggml_geglu_swapped(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
+ }
+
+ struct ggml_tensor * ggml_geglu_split(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b) {
+     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
+ }
+
+ // ggml_swiglu
+
+ struct ggml_tensor * ggml_swiglu(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
+ }
+
+ struct ggml_tensor * ggml_swiglu_swapped(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a) {
+     return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
+ }
+
+ struct ggml_tensor * ggml_swiglu_split(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b) {
+     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
+ }
+
  // ggml_norm

  static struct ggml_tensor * ggml_norm_impl(
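
The new GGML_OP_GLU family covers ReGLU/GEGLU/SwiGLU in one op. In the fused form the gate and value halves are packed along dim 0 (so ne0 must be even and the result is half as wide); the _split form takes them as two tensors of equal shape. A minimal sketch (not part of the diff), assuming the entry points above:

    #include "ggml.h"

    void glu_example(struct ggml_context * ctx) {
        // fused: [8, 4] packs gate+value along dim 0 -> result is [4, 4]
        struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        struct ggml_tensor * out = ggml_swiglu(ctx, a);

        // split: gate and value as separate [8, 4] tensors -> result is [8, 4]
        struct ggml_tensor * g    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        struct ggml_tensor * v    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        struct ggml_tensor * out2 = ggml_glu_split(ctx, g, v, GGML_GLU_OP_SWIGLU);

        (void) out; (void) out2;
    }
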
@@ -2686,6 +2860,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
      return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
  }

+ // ggml_l2_norm
+
+ static struct ggml_tensor * ggml_l2_norm_impl(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         float eps,
+         bool inplace) {
+     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+     ggml_set_op_params_f32(result, 0, eps);
+
+     result->op = GGML_OP_L2_NORM;
+     result->src[0] = a;
+
+     return result;
+ }
+
+ struct ggml_tensor * ggml_l2_norm(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         float eps) {
+     return ggml_l2_norm_impl(ctx, a, eps, false);
+ }
+
+ struct ggml_tensor * ggml_l2_norm_inplace(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         float eps) {
+     return ggml_l2_norm_impl(ctx, a, eps, true);
+ }
+
  // ggml_mul_mat

  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
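
GGML_OP_L2_NORM (added above) rescales each row to unit Euclidean length, with eps guarding the division; it arrives alongside the RWKV7/gated-linear-attention ops later in this diff. A minimal sketch (not part of the diff):

    #include "ggml.h"

    struct ggml_tensor * l2_example(struct ggml_context * ctx) {
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
        return ggml_l2_norm(ctx, x, 1e-12f); // each of the 16 rows gets unit L2 norm
    }
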
@@ -2729,11 +2934,11 @@ void ggml_mul_mat_set_prec(
      c = ggml_mul_mat_id(ctx, as, b, ids);

      as  -> [cols, rows, n_expert]
-     ids -> [n_experts_used, n_tokens] (i32)
      b   -> [cols, n_expert_used, n_tokens]
+     ids -> [n_expert_used, n_tokens] (i32)
      c   -> [rows, n_expert_used, n_tokens]

-     in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+     in b, n_expert_used can be broadcasted to match the n_expert_used of ids

      c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
  */
@@ -3323,6 +3528,35 @@ struct ggml_tensor * ggml_get_rows_back(
      return result;
  }

+ // ggml_set_rows
+
+ struct ggml_tensor * ggml_set_rows(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         struct ggml_tensor * c) {
+     GGML_ASSERT(a->ne[0] == b->ne[0]);
+     GGML_ASSERT(a->ne[2] == b->ne[2]);
+     GGML_ASSERT(a->ne[3] == b->ne[3]);
+     GGML_ASSERT(b->ne[1] == c->ne[0]);
+     GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+     GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+     GGML_ASSERT(c->ne[3] == 1);
+     GGML_ASSERT(b->type == GGML_TYPE_F32);
+     GGML_ASSERT(c->type == GGML_TYPE_I64);
+
+     GGML_ASSERT(ggml_is_contiguous_rows(a));
+     GGML_ASSERT(ggml_is_contiguous_rows(b));
+
+     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+     result->op = GGML_OP_SET_ROWS;
+     result->src[0] = b;
+     result->src[1] = c;
+
+     return result;
+ }
+
  // ggml_diag

  struct ggml_tensor * ggml_diag(
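
GGML_OP_SET_ROWS (added above) is the scatter counterpart of ggml_get_rows: it writes f32 rows from b into a at the i64 row indices held in c (only b's type is constrained by the asserts, so the destination may use a different type). A minimal sketch (not part of the diff) with illustrative shapes:

    #include "ggml.h"

    struct ggml_tensor * scatter_rows(struct ggml_context * ctx) {
        // scatter 4 rows of width 64 into a 128-row destination
        struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 128);
        struct ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
        struct ggml_tensor * idx  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);
        // idx->data would hold the 4 destination row indices at compute time
        return ggml_set_rows(ctx, dst, rows, idx);
    }
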
@@ -3459,12 +3693,14 @@ struct ggml_tensor * ggml_soft_max_ext(
      return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
  }

- // ggml_soft_max_back
+ // ggml_soft_max_ext_back

- static struct ggml_tensor * ggml_soft_max_back_impl(
+ static struct ggml_tensor * ggml_soft_max_ext_back_impl(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
+         float scale,
+         float max_bias,
          bool inplace) {
      struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

@@ -3472,21 +3708,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
      result->src[0] = a;
      result->src[1] = b;

+     memcpy((float *) result->op_params + 0, &scale, sizeof(float));
+     memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
+
      return result;
  }

- struct ggml_tensor * ggml_soft_max_back(
+ struct ggml_tensor * ggml_soft_max_ext_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
-         struct ggml_tensor * b) {
-     return ggml_soft_max_back_impl(ctx, a, b, false);
+         struct ggml_tensor * b,
+         float scale,
+         float max_bias) {
+     return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
  }

- struct ggml_tensor * ggml_soft_max_back_inplace(
+ struct ggml_tensor * ggml_soft_max_ext_back_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
-         struct ggml_tensor * b) {
-     return ggml_soft_max_back_impl(ctx, a, b, true);
+         struct ggml_tensor * b,
+         float scale,
+         float max_bias) {
+     return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
  }

  // ggml_rope
@@ -3704,7 +3947,7 @@ void ggml_rope_yarn_corr_dims(

  // ggml_rope_back

- struct ggml_tensor * ggml_rope_back(
+ struct ggml_tensor * ggml_rope_ext_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
@@ -3718,29 +3961,32 @@ struct ggml_tensor * ggml_rope_back(
          float attn_factor,
          float beta_fast,
          float beta_slow) {
-     GGML_ASSERT(ggml_is_vector(b));
-     GGML_ASSERT(b->type == GGML_TYPE_I32);
-     GGML_ASSERT(a->ne[2] == b->ne[0]);
-
-     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-     int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-     memcpy(params + 5, &freq_base, sizeof(float));
-     memcpy(params + 6, &freq_scale, sizeof(float));
-     memcpy(params + 7, &ext_factor, sizeof(float));
-     memcpy(params + 8, &attn_factor, sizeof(float));
-     memcpy(params + 9, &beta_fast, sizeof(float));
-     memcpy(params + 10, &beta_slow, sizeof(float));
-     ggml_set_op_params(result, params, sizeof(params));
-
-     result->op = GGML_OP_ROPE_BACK;
-     result->src[0] = a;
-     result->src[1] = b;
-     result->src[2] = c;
-
+     struct ggml_tensor * result = ggml_rope_ext(
+         ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+     result->op = GGML_OP_ROPE_BACK;
      return result;
  }

+ struct ggml_tensor * ggml_rope_multi_back(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         struct ggml_tensor * c,
+         int n_dims,
+         int sections[4],
+         int mode,
+         int n_ctx_orig,
+         float freq_base,
+         float freq_scale,
+         float ext_factor,
+         float attn_factor,
+         float beta_fast,
+         float beta_slow) {
+     struct ggml_tensor * result = ggml_rope_multi(
+         ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+     result->op = GGML_OP_ROPE_BACK;
+     return result;
+ }
  // ggml_clamp

  struct ggml_tensor * ggml_clamp(
@@ -3760,174 +4006,183 @@ struct ggml_tensor * ggml_clamp(
      return result;
  }

- // ggml_conv_1d
-
  static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
      return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
  }

- GGML_API struct ggml_tensor * ggml_conv_1d(
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ // a: [OC,IC, KH, KW]
+ // b: [N, IC, IH, IW]
+ // result: [N, OH, OW, IC*KH*KW]
+ struct ggml_tensor * ggml_im2col(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int s0,
+         int s1,
          int p0,
-         int d0) {
-     struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
+         int p1,
+         int d0,
+         int d1,
+         bool is_2D,
+         enum ggml_type dst_type) {
+     if (is_2D) {
+         GGML_ASSERT(a->ne[2] == b->ne[2]);
+     } else {
+         //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+         GGML_ASSERT(b->ne[1] == a->ne[1]);
+         GGML_ASSERT(b->ne[3] == 1);
+     }

-     struct ggml_tensor * result =
-         ggml_mul_mat(ctx,
-                 ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
-                 ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
-
-     result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
+     const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+     const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

-     return result;
- }
+     GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
+     GGML_ASSERT((OW > 0) && "b too small compared to a");

- // ggml_conv_1d_ph
+     const int64_t ne[4] = {
+         is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
+         OW,
+         is_2D ? OH : b->ne[2],
+         is_2D ? b->ne[3] : 1,
+     };

- struct ggml_tensor* ggml_conv_1d_ph(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         struct ggml_tensor * b,
-         int s,
-         int d) {
-     return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
- }
+     struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+     ggml_set_op_params(result, params, sizeof(params));

- // ggml_conv_transpose_1d
+     result->op = GGML_OP_IM2COL;
+     result->src[0] = a;
+     result->src[1] = b;

- static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
-     return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+     return result;
  }

- GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+ struct ggml_tensor * ggml_im2col_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
+         int64_t * ne,
          int s0,
+         int s1,
          int p0,
-         int d0) {
-     GGML_ASSERT(ggml_is_matrix(b));
-     GGML_ASSERT(a->ne[2] == b->ne[1]);
-     GGML_ASSERT(a->ne[3] == 1);
-
-     GGML_ASSERT(p0 == 0);
-     GGML_ASSERT(d0 == 1);
-
-     const int64_t ne[4] = {
-         ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
-         a->ne[1], b->ne[2], 1,
-     };
+         int p1,
+         int d0,
+         int d1,
+         bool is_2D) {
      struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-     int32_t params[] = { s0, p0, d0 };
+     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
      ggml_set_op_params(result, params, sizeof(params));

-     result->op = GGML_OP_CONV_TRANSPOSE_1D;
+     result->op = GGML_OP_IM2COL_BACK;
      result->src[0] = a;
      result->src[1] = b;

      return result;
  }

- // ggml_conv_depthwise
+ // ggml_conv_1d

- struct ggml_tensor * ggml_conv_depthwise_2d(
+ struct ggml_tensor * ggml_conv_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int s0,
-         int s1,
          int p0,
-         int p1,
-         int d0,
-         int d1) {
-     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
-     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
-                                               ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
-                                               s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
-     struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+         int d0) {
+     struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]

-     new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
-     struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
-     result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+     struct ggml_tensor * result =
+         ggml_mul_mat(ctx,
+                 ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                 ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
+
+     result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]

      return result;
  }
- // ggml_conv_2d

- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
- // a: [OC,IC, KH, KW]
- // b: [N, IC, IH, IW]
- // result: [N, OH, OW, IC*KH*KW]
- struct ggml_tensor * ggml_im2col(
+ // ggml_conv_1d_ph
+
+ struct ggml_tensor* ggml_conv_1d_ph(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         int s,
+         int d) {
+     return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+ }
+
+ // ggml_conv_1d_dw
+
+ struct ggml_tensor * ggml_conv_1d_dw(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int s0,
-         int s1,
          int p0,
-         int p1,
-         int d0,
-         int d1,
-         bool is_2D,
-         enum ggml_type dst_type) {
-     if(is_2D) {
-         GGML_ASSERT(a->ne[2] == b->ne[2]);
-     } else {
-         GGML_ASSERT(a->ne[1] == b->ne[1]);
-         GGML_ASSERT(b->ne[3] == 1);
-     }
+         int d0) {
+     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
+     struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);

-     const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
-     const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);

-     GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
-     GGML_ASSERT((OW > 0) && "b too small compared to a");
+     struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);

-     const int64_t ne[4] = {
-         is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
-         OW,
-         is_2D ? OH : b->ne[2],
-         is_2D ? b->ne[3] : 1,
-     };
+     result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);

-     struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
-     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
-     ggml_set_op_params(result, params, sizeof(params));
+     return result;
+ }

-     result->op = GGML_OP_IM2COL;
-     result->src[0] = a;
-     result->src[1] = b;
+ // ggml_conv_1d_dw_ph

-     return result;
+ struct ggml_tensor * ggml_conv_1d_dw_ph(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         int s0,
+         int d0) {
+     return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
  }

- struct ggml_tensor * ggml_im2col_back(
+ // ggml_conv_transpose_1d
+
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+     return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+ }
+
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
-         int64_t * ne,
          int s0,
-         int s1,
          int p0,
-         int p1,
-         int d0,
-         int d1,
-         bool is_2D) {
+         int d0) {
+     GGML_ASSERT(ggml_is_matrix(b));
+     GGML_ASSERT(a->ne[2] == b->ne[1]);
+     GGML_ASSERT(a->ne[3] == 1);
+
+     GGML_ASSERT(p0 == 0);
+     GGML_ASSERT(d0 == 1);
+
+     const int64_t ne[4] = {
+         ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+         a->ne[1], b->ne[2], 1,
+     };
      struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+
+     int32_t params[] = { s0, p0, d0 };
      ggml_set_op_params(result, params, sizeof(params));

-     result->op = GGML_OP_IM2COL_BACK;
+     result->op = GGML_OP_CONV_TRANSPOSE_1D;
      result->src[0] = a;
      result->src[1] = b;

      return result;
  }

+ // ggml_conv_2d
+
  // a: [OC,IC, KH, KW]
  // b: [N, IC, IH, IW]
  // result: [N, OC, OH, OW]
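
Beyond moving ggml_im2col ahead of its callers, the hunk above adds depthwise 1-D convolution (one filter per channel). A minimal sketch of the new ggml_conv_1d_dw_ph entry point (not part of the diff); shapes are illustrative:

    #include "ggml.h"

    struct ggml_tensor * dw_conv1d(struct ggml_context * ctx) {
        // kernel [K = 3, C = 8], signal [T = 100, C = 8, N = 1]
        struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 8);
        struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 100, 8, 1);
        // stride 1, dilation 1; "ph" picks half padding so T stays 100
        return ggml_conv_1d_dw_ph(ctx, w, x, 1, 1);
    }
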
@@ -3973,6 +4228,109 @@ struct ggml_tensor * ggml_conv_2d_s1_ph(
      return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
  }

+ // ggml_conv_2d_dw
+
+ struct ggml_tensor * ggml_conv_2d_dw(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         int s0,
+         int s1,
+         int p0,
+         int p1,
+         int d0,
+         int d1) {
+     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
+     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
+                                               ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
+                                               s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+     struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+
+     new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+     struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
+     result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+
+     return result;
+ }
+
+ // ggml_conv_2d_dw_direct
+
+ struct ggml_tensor * ggml_conv_2d_dw_direct(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         struct ggml_tensor * b,
+         int stride0,
+         int stride1,
+         int pad0,
+         int pad1,
+         int dilation0,
+         int dilation1) {
+     GGML_ASSERT(a->ne[2] == 1);
+     GGML_ASSERT(a->ne[3] == b->ne[2]);
+     int64_t ne[4];
+     ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+     ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+     ne[2] = b->ne[2];
+     ne[3] = b->ne[3];
+
+     struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+     if (ggml_is_contiguous_channels(b)) {
+         // Result will be permuted the same way as input (CWHN order)
+         const int64_t type_size = ggml_type_size(result->type);
+         GGML_ASSERT(ggml_blck_size(result->type) == 1);
+         result->nb[0] = result->ne[2] * type_size;
+         result->nb[1] = result->ne[0] * result->nb[0];
+         result->nb[2] = type_size;
+     }
+
+     int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+     ggml_set_op_params(result, params, sizeof(params));
+
+     result->op = GGML_OP_CONV_2D_DW;
+     result->src[0] = a;
+     result->src[1] = b;
+     return result;
+ }
+
+ // ggml_conv_2d_direct
+
+ struct ggml_tensor * ggml_conv_2d_direct(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,   // convolution kernel [KW, KH, IC, OC]
+         struct ggml_tensor * b,   // input data [W, H, C, N]
+         int s0,                   // stride dimension 0
+         int s1,                   // stride dimension 1
+         int p0,                   // padding dimension 0
+         int p1,                   // padding dimension 1
+         int d0,                   // dilation dimension 0
+         int d1) {                 // dilation dimension 1
+
+     GGML_ASSERT(a->ne[2] == b->ne[2]);
+     //GGML_ASSERT(a->type == b->type);
+
+     int64_t ne[4];
+     ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+     ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+     ne[2] = a->ne[3];
+     ne[3] = b->ne[3];
+
+     struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+     ggml_set_op_params_i32(result, 0, s0);
+     ggml_set_op_params_i32(result, 1, s1);
+     ggml_set_op_params_i32(result, 2, p0);
+     ggml_set_op_params_i32(result, 3, p1);
+     ggml_set_op_params_i32(result, 4, d0);
+     ggml_set_op_params_i32(result, 5, d1);
+
+     result->op = GGML_OP_CONV_2D;
+     result->src[0] = a;
+     result->src[1] = b;
+
+     return result;
+ }
+
  // ggml_conv_transpose_2d_p0

  static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
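
GGML_OP_CONV_2D_DW (added above) computes depthwise 2-D convolution directly instead of lowering through im2col + mul_mat, and GGML_OP_CONV_2D does the same for regular convolution. A minimal sketch of the depthwise variant (not part of the diff); per the asserts the kernel is [KW, KH, 1, C] and the input [W, H, C, N]:

    #include "ggml.h"

    struct ggml_tensor * dw_conv2d(struct ggml_context * ctx) {
        // 3x3 depthwise conv over a 32x32, 16-channel image; stride 1, pad 1
        struct ggml_tensor * w = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, 16);
        struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 16, 1);
        return ggml_conv_2d_dw_direct(ctx, w, x, 1, 1, 1, 1, 1, 1); // -> [32, 32, 16, 1]
    }
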
@@ -4089,22 +4447,22 @@ struct ggml_tensor * ggml_pool_2d_back(
      return result;
  }

- // ggml_upscale
+ // ggml_upscale / ggml_interpolate

- static struct ggml_tensor * ggml_upscale_impl(
+ static struct ggml_tensor * ggml_interpolate_impl(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
-         int ne0,
-         int ne1,
-         int ne2,
-         int ne3) {
-     GGML_ASSERT(a->ne[0] <= ne0);
-     GGML_ASSERT(a->ne[1] <= ne1);
-     GGML_ASSERT(a->ne[2] <= ne2);
-     GGML_ASSERT(a->ne[3] <= ne3);
+         int64_t ne0,
+         int64_t ne1,
+         int64_t ne2,
+         int64_t ne3,
+         uint32_t mode) {
+     GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);

      struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);

+     ggml_set_op_params_i32(result, 0, (int32_t)mode);
+
      result->op = GGML_OP_UPSCALE;
      result->src[0] = a;

@@ -4114,8 +4472,10 @@ static struct ggml_tensor * ggml_upscale_impl(
  struct ggml_tensor * ggml_upscale(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
-         int scale_factor) {
-     return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+         int scale_factor,
+         enum ggml_scale_mode mode) {
+     GGML_ASSERT(scale_factor > 1);
+     return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
  }

  struct ggml_tensor * ggml_upscale_ext(
@@ -4124,8 +4484,20 @@ struct ggml_tensor * ggml_upscale_ext(
          int ne0,
          int ne1,
          int ne2,
-         int ne3) {
-     return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+         int ne3,
+         enum ggml_scale_mode mode) {
+     return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
+ }
+
+ struct ggml_tensor * ggml_interpolate(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         int64_t ne0,
+         int64_t ne1,
+         int64_t ne2,
+         int64_t ne3,
+         uint32_t mode) {
+     return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  }

  // ggml_pad
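
Upscaling now carries an explicit scale mode, and because the old a->ne[i] <= nei asserts are gone, ggml_interpolate can shrink as well as grow. A minimal sketch (not part of the diff); GGML_SCALE_MODE_NEAREST is assumed from the matching ggml.h enum, since only GGML_SCALE_MODE_COUNT is visible in the hunk above:

    #include "ggml.h"

    struct ggml_tensor * upscale2x(struct ggml_context * ctx) {
        // 2x nearest-neighbour upscale of a [64, 64, 3, 1] image tensor
        struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 3, 1);
        return ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST); // -> [128, 128, 3, 1]
    }
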
@@ -4180,6 +4552,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
      return result;
  }

+ // ggml_roll
+
+ struct ggml_tensor * ggml_roll(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         int shift0,
+         int shift1,
+         int shift2,
+         int shift3) {
+     GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
+     GGML_ASSERT(abs(shift0) < a->ne[0]);
+     GGML_ASSERT(abs(shift1) < a->ne[1]);
+     GGML_ASSERT(abs(shift2) < a->ne[2]);
+     GGML_ASSERT(abs(shift3) < a->ne[3]);
+
+     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+     ggml_set_op_params_i32(result, 0, shift0);
+     ggml_set_op_params_i32(result, 1, shift1);
+     ggml_set_op_params_i32(result, 2, shift2);
+     ggml_set_op_params_i32(result, 3, shift3);
+
+     result->op = GGML_OP_ROLL;
+     result->src[0] = a;
+
+     return result;
+ }
+
  // ggml_arange

  struct ggml_tensor * ggml_arange(
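
GGML_OP_ROLL (added above) rotates a tensor along each dimension by a signed shift with wrap-around, in the spirit of numpy.roll. A minimal sketch (not part of the diff):

    #include "ggml.h"

    struct ggml_tensor * roll_example(struct ggml_context * ctx) {
        // rotate a [16, 8] tensor 3 positions along dim 0
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8);
        return ggml_roll(ctx, x, 3, 0, 0, 0); // element i moves to (i + 3) mod 16
    }
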
@@ -4288,7 +4688,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
      }

      // permute(0, 2, 1, 3)
-     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+     int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
      struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

      float params[] = { scale, max_bias, logit_softcap };
@@ -4606,15 +5006,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
      GGML_ASSERT(ggml_is_contiguous(state));

      const int64_t S = k->ne[0];
-     const int64_t H = k->ne[2];
-     const int64_t n_tokens = k->ne[3];
+     const int64_t H = k->ne[1];
+     const int64_t n_tokens = k->ne[2];
      const int64_t n_seqs = state->ne[1];
      {
-         GGML_ASSERT(k->ne[1] == 1);
-         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-         // TODO: RWKV v4 and v5
-         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+         GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+         GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+         GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
          GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
      }

@@ -4633,210 +5031,128 @@ struct ggml_tensor * ggml_rwkv_wkv6(
      return result;
  }

- // ggml_unary
+ // ggml_gated_linear_attn

- static struct ggml_tensor * ggml_unary_impl(
+ struct ggml_tensor * ggml_gated_linear_attn(
          struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         enum ggml_unary_op op,
-         bool inplace) {
-     GGML_ASSERT(ggml_is_contiguous_1(a));
-
-     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-     ggml_set_op_params_i32(result, 0, (int32_t) op);
-
-     result->op = GGML_OP_UNARY;
-     result->src[0] = a;
-
-     return result;
- }
-
- struct ggml_tensor * ggml_unary(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         enum ggml_unary_op op) {
-     return ggml_unary_impl(ctx, a, op, false);
- }
-
- struct ggml_tensor * ggml_unary_inplace(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         enum ggml_unary_op op) {
-     return ggml_unary_impl(ctx, a, op, true);
- }
-
- // ggml_map_unary
-
- static struct ggml_tensor * ggml_map_unary_impl_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         const ggml_unary_op_f32_t fun,
-         bool inplace) {
-     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = GGML_OP_MAP_UNARY;
-     result->src[0] = a;
-
-     return result;
- }
-
- struct ggml_tensor * ggml_map_unary_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         const ggml_unary_op_f32_t fun) {
-     return ggml_map_unary_impl_f32(ctx, a, fun, false);
- }
-
- struct ggml_tensor * ggml_map_unary_inplace_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         const ggml_unary_op_f32_t fun) {
-     return ggml_map_unary_impl_f32(ctx, a, fun, true);
- }
-
- // ggml_map_binary
-
- static struct ggml_tensor * ggml_map_binary_impl_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         struct ggml_tensor * b,
-         const ggml_binary_op_f32_t fun,
-         bool inplace) {
-     GGML_ASSERT(ggml_are_same_shape(a, b));
-
-     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-     result->op = GGML_OP_MAP_BINARY;
-     result->src[0] = a;
-     result->src[1] = b;
-
-     return result;
- }
-
- struct ggml_tensor * ggml_map_binary_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         struct ggml_tensor * b,
-         const ggml_binary_op_f32_t fun) {
-     return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
- }
-
- struct ggml_tensor * ggml_map_binary_inplace_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         struct ggml_tensor * b,
-         const ggml_binary_op_f32_t fun) {
-     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
- }
+         struct ggml_tensor * k,
+         struct ggml_tensor * v,
+         struct ggml_tensor * q,
+         struct ggml_tensor * g,
+         struct ggml_tensor * state,
+         float scale) {
+     GGML_ASSERT(ggml_is_contiguous(k));
+     GGML_ASSERT(ggml_is_contiguous(v));
+     GGML_ASSERT(ggml_is_contiguous(q));
+     GGML_ASSERT(ggml_is_contiguous(g));
+     GGML_ASSERT(ggml_is_contiguous(state));

- // ggml_map_custom1_f32
+     const int64_t S = k->ne[0];
+     const int64_t H = k->ne[1];
+     const int64_t n_tokens = k->ne[2];
+     const int64_t n_seqs = state->ne[1];
+     {
+         GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+         GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+         GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+     }

- static struct ggml_tensor * ggml_map_custom1_impl_f32(
-         struct ggml_context * ctx,
-         struct ggml_tensor * a,
-         const ggml_custom1_op_f32_t fun,
-         bool inplace) {
-     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+     // concat output and new_state
+     const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
5063
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4745
5064
 
4746
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
5065
+ ggml_set_op_params_f32(result, 0, scale);
4747
5066
 
4748
- result->op = GGML_OP_MAP_CUSTOM1_F32;
4749
- result->src[0] = a;
5067
+ result->op = GGML_OP_GATED_LINEAR_ATTN;
5068
+ result->src[0] = k;
5069
+ result->src[1] = v;
5070
+ result->src[2] = q;
5071
+ result->src[3] = g;
5072
+ result->src[4] = state;
4750
5073
 
4751
5074
  return result;
4752
5075
  }
4753
5076
 
4754
- struct ggml_tensor * ggml_map_custom1_f32(
4755
- struct ggml_context * ctx,
4756
- struct ggml_tensor * a,
4757
- const ggml_custom1_op_f32_t fun) {
4758
- return ggml_map_custom1_impl_f32(ctx, a, fun, false);
4759
- }
4760
-
4761
- struct ggml_tensor * ggml_map_custom1_inplace_f32(
4762
- struct ggml_context * ctx,
4763
- struct ggml_tensor * a,
4764
- const ggml_custom1_op_f32_t fun) {
4765
- return ggml_map_custom1_impl_f32(ctx, a, fun, true);
4766
- }
5077
+ // ggml_rwkv_wkv7
4767
5078
 
4768
- // ggml_map_custom2_f32
5079
+ struct ggml_tensor * ggml_rwkv_wkv7(
5080
+ struct ggml_context * ctx,
5081
+ struct ggml_tensor * r,
5082
+ struct ggml_tensor * w,
5083
+ struct ggml_tensor * k,
5084
+ struct ggml_tensor * v,
5085
+ struct ggml_tensor * a,
5086
+ struct ggml_tensor * b,
5087
+ struct ggml_tensor * state) {
5088
+ GGML_ASSERT(ggml_is_contiguous(r));
5089
+ GGML_ASSERT(ggml_is_contiguous(w));
5090
+ GGML_ASSERT(ggml_is_contiguous(k));
5091
+ GGML_ASSERT(ggml_is_contiguous(v));
5092
+ GGML_ASSERT(ggml_is_contiguous(a));
5093
+ GGML_ASSERT(ggml_is_contiguous(b));
5094
+ GGML_ASSERT(ggml_is_contiguous(state));
4769
5095
 
4770
- static struct ggml_tensor * ggml_map_custom2_impl_f32(
4771
- struct ggml_context * ctx,
4772
- struct ggml_tensor * a,
4773
- struct ggml_tensor * b,
4774
- const ggml_custom2_op_f32_t fun,
4775
- bool inplace) {
4776
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5096
+ const int64_t S = k->ne[0];
5097
+ const int64_t H = k->ne[1];
5098
+ const int64_t n_tokens = k->ne[2];
5099
+ const int64_t n_seqs = state->ne[1];
5100
+ {
5101
+ GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
5102
+ GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
5103
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
5104
+ GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
5105
+ GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
5106
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
5107
+ }
4777
5108
 
4778
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
5109
+ // concat output and new_state
5110
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
5111
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4779
5112
 
4780
- result->op = GGML_OP_MAP_CUSTOM2_F32;
4781
- result->src[0] = a;
4782
- result->src[1] = b;
5113
+ result->op = GGML_OP_RWKV_WKV7;
5114
+ result->src[0] = r;
5115
+ result->src[1] = w;
5116
+ result->src[2] = k;
5117
+ result->src[3] = v;
5118
+ result->src[4] = a;
5119
+ result->src[5] = b;
5120
+ result->src[6] = state;
4783
5121
 
4784
5122
  return result;
4785
5123
  }
4786
5124
 
4787
- struct ggml_tensor * ggml_map_custom2_f32(
4788
- struct ggml_context * ctx,
4789
- struct ggml_tensor * a,
4790
- struct ggml_tensor * b,
4791
- const ggml_custom2_op_f32_t fun) {
4792
- return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
4793
- }
4794
-
4795
- struct ggml_tensor * ggml_map_custom2_inplace_f32(
4796
- struct ggml_context * ctx,
4797
- struct ggml_tensor * a,
4798
- struct ggml_tensor * b,
4799
- const ggml_custom2_op_f32_t fun) {
4800
- return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
4801
- }
5125
+ // ggml_unary
4802
5126
 
4803
- // ggml_map_custom3_f32
5127
+ static struct ggml_tensor * ggml_unary_impl(
5128
+ struct ggml_context * ctx,
5129
+ struct ggml_tensor * a,
5130
+ enum ggml_unary_op op,
5131
+ bool inplace) {
5132
+ GGML_ASSERT(ggml_is_contiguous_1(a));
4804
5133
 
4805
- static struct ggml_tensor * ggml_map_custom3_impl_f32(
4806
- struct ggml_context * ctx,
4807
- struct ggml_tensor * a,
4808
- struct ggml_tensor * b,
4809
- struct ggml_tensor * c,
4810
- const ggml_custom3_op_f32_t fun,
4811
- bool inplace) {
4812
5134
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4813
5135
 
4814
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
5136
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
4815
5137
 
4816
- result->op = GGML_OP_MAP_CUSTOM3_F32;
5138
+ result->op = GGML_OP_UNARY;
4817
5139
  result->src[0] = a;
4818
- result->src[1] = b;
4819
- result->src[2] = c;
4820
5140
 
4821
5141
  return result;
4822
5142
  }
4823
5143
 
4824
- struct ggml_tensor * ggml_map_custom3_f32(
4825
- struct ggml_context * ctx,
4826
- struct ggml_tensor * a,
4827
- struct ggml_tensor * b,
4828
- struct ggml_tensor * c,
4829
- const ggml_custom3_op_f32_t fun) {
4830
- return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
5144
+ struct ggml_tensor * ggml_unary(
5145
+ struct ggml_context * ctx,
5146
+ struct ggml_tensor * a,
5147
+ enum ggml_unary_op op) {
5148
+ return ggml_unary_impl(ctx, a, op, false);
4831
5149
  }
4832
5150
 
4833
- struct ggml_tensor * ggml_map_custom3_inplace_f32(
4834
- struct ggml_context * ctx,
4835
- struct ggml_tensor * a,
4836
- struct ggml_tensor * b,
4837
- struct ggml_tensor * c,
4838
- const ggml_custom3_op_f32_t fun) {
4839
- return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
5151
+ struct ggml_tensor * ggml_unary_inplace(
5152
+ struct ggml_context * ctx,
5153
+ struct ggml_tensor * a,
5154
+ enum ggml_unary_op op) {
5155
+ return ggml_unary_impl(ctx, a, op, true);
4840
5156
  }
4841
5157
 
4842
5158
  // ggml_map_custom1
@@ -4857,7 +5173,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
4857
5173
  /*.n_tasks =*/ n_tasks,
4858
5174
  /*.userdata =*/ userdata
4859
5175
  };
4860
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5176
+ ggml_set_op_params(result, &params, sizeof(params));
4861
5177
 
4862
5178
  result->op = GGML_OP_MAP_CUSTOM1;
4863
5179
  result->src[0] = a;
@@ -4902,7 +5218,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
4902
5218
  /*.n_tasks =*/ n_tasks,
4903
5219
  /*.userdata =*/ userdata
4904
5220
  };
4905
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5221
+ ggml_set_op_params(result, &params, sizeof(params));
4906
5222
 
4907
5223
  result->op = GGML_OP_MAP_CUSTOM2;
4908
5224
  result->src[0] = a;
@@ -4951,7 +5267,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
4951
5267
  /*.n_tasks =*/ n_tasks,
4952
5268
  /*.userdata =*/ userdata
4953
5269
  };
4954
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5270
+ ggml_set_op_params(result, &params, sizeof(params));
4955
5271
 
4956
5272
  result->op = GGML_OP_MAP_CUSTOM3;
4957
5273
  result->src[0] = a;
@@ -4983,6 +5299,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
4983
5299
  return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
4984
5300
  }
4985
5301
 
5302
+ struct ggml_tensor * ggml_custom_4d(
5303
+ struct ggml_context * ctx,
5304
+ enum ggml_type type,
5305
+ int64_t ne0,
5306
+ int64_t ne1,
5307
+ int64_t ne2,
5308
+ int64_t ne3,
5309
+ struct ggml_tensor ** args,
5310
+ int n_args,
5311
+ ggml_custom_op_t fun,
5312
+ int n_tasks,
5313
+ void * userdata) {
5314
+
5315
+ GGML_ASSERT(n_args < GGML_MAX_SRC);
5316
+
5317
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
5318
+
5319
+ struct ggml_custom_op_params params = {
5320
+ /*.fun =*/ fun,
5321
+ /*.n_tasks =*/ n_tasks,
5322
+ /*.userdata =*/ userdata
5323
+ };
5324
+ ggml_set_op_params(result, &params, sizeof(params));
5325
+
5326
+ result->op = GGML_OP_CUSTOM;
5327
+ for (int i = 0; i < n_args; i++) {
5328
+ result->src[i] = args[i];
5329
+ }
5330
+
5331
+ return result;
5332
+ }
5333
+
5334
+ struct ggml_tensor * ggml_custom_inplace(
5335
+ struct ggml_context * ctx,
5336
+ struct ggml_tensor * a,
5337
+ struct ggml_tensor ** args,
5338
+ int n_args,
5339
+ ggml_custom_op_t fun,
5340
+ int n_tasks,
5341
+ void * userdata) {
5342
+
5343
+ GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
5344
+
5345
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5346
+
5347
+ struct ggml_custom_op_params params = {
5348
+ /*.fun =*/ fun,
5349
+ /*.n_tasks =*/ n_tasks,
5350
+ /*.userdata =*/ userdata
5351
+ };
5352
+ ggml_set_op_params(result, &params, sizeof(params));
5353
+
5354
+ result->op = GGML_OP_CUSTOM;
5355
+ result->src[0] = a;
5356
+ for (int i = 0; i < n_args; i++) {
5357
+ result->src[i + 1] = args[i];
5358
+ }
5359
+
5360
+ return result;
5361
+ }
4986
5362
  // ggml_cross_entropy_loss
4987
5363
 
4988
5364
  struct ggml_tensor * ggml_cross_entropy_loss(
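
The ggml_custom_4d / ggml_custom_inplace pair added above replaces the typed ggml_map_custom*_f32 variants that this version removes: the callback now receives only the destination tensor plus a thread index/count and reads its operands back through dst->src[..]. A hedged sketch under the assumption that ggml_custom_op_t is void (*)(struct ggml_tensor * dst, int ith, int nth, void * userdata); the op itself (my_copy_op) is illustrative:

    /* illustrative custom op: copy src[0] into dst, work split across nth threads
       (assumes contiguous F32 tensors of equal shape) */
    static void my_copy_op(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
        (void) userdata;
        const struct ggml_tensor * src = dst->src[0];
        const int64_t n  = ggml_nelements(dst);
        const int64_t dr = (n + nth - 1)/nth;            // elements per thread
        const int64_t i0 = dr*ith;
        const int64_t i1 = i0 + dr < n ? i0 + dr : n;
        for (int64_t i = i0; i < i1; ++i) {
            ((float *) dst->data)[i] = ((const float *) src->data)[i];
        }
    }

    /* call site: one argument tensor, output shaped like the input */
    struct ggml_tensor * args[1] = { a };
    struct ggml_tensor * out = ggml_custom_4d(ctx, GGML_TYPE_F32,
        a->ne[0], a->ne[1], a->ne[2], a->ne[3], args, 1, my_copy_op, GGML_N_TASKS_MAX, NULL);
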
@@ -5007,10 +5383,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  struct ggml_tensor * c) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
- GGML_ASSERT(ggml_is_scalar(c));
+ GGML_ASSERT(ggml_is_scalar(a));
+ GGML_ASSERT(ggml_are_same_shape(b, c));

- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, b);

  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
  result->src[0] = a;
@@ -5189,7 +5565,7 @@ static void ggml_sub_or_set(
  }

  static void ggml_compute_backward(
- struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
  struct ggml_tensor * tensor = cgraph->nodes[i];
  struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);

@@ -5261,7 +5637,7 @@ static void ggml_compute_backward(
  } break;
  case GGML_OP_MUL: {
  if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
  }
  if (src1_needs_grads) {
  struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5333,7 +5709,7 @@ static void ggml_compute_backward(
  if (src0_needs_grads) {
  float eps;
  memcpy(&eps, tensor->op_params, sizeof(float));
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
  }
  } break;
  case GGML_OP_MUL_MAT: {
@@ -5353,21 +5729,25 @@ static void ggml_compute_backward(
  // src1.shape [n,p,qq,rr]

  if (src0_needs_grads) {
- struct ggml_tensor * s1_tg =
+ GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+ GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+ struct ggml_tensor * tmp =
  ggml_out_prod(ctx, // [n,m,qq,rr]
  src1, // [n,p,qq,rr]
  grad); // [m,p,qq,rr]
- const int64_t qq = s1_tg->ne[2];
- const int64_t rr = s1_tg->ne[3];
- const int64_t q1 = src0->ne[2];
- const int64_t r1 = src0->ne[3];
- const bool ne2_broadcasted = qq > q1;
- const bool ne3_broadcasted = rr > r1;
- if (ne2_broadcasted || ne3_broadcasted) {
- // sum broadcast repetitions of s1_tg into shape of src0
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+ if (!ggml_are_same_shape(tmp, src0)) {
+ GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+ GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+ GGML_ASSERT(tmp->ne[3] == 1);
+
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+ const size_t nb2 = tmp->nb[2] * nr2;
+ const size_t nb3 = tmp->nb[2];
+
+ tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+ tmp = ggml_repeat_back(ctx, tmp, src0);
  }
- ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+ ggml_add_or_set(ctx, cgraph, isrc0, tmp);
  }
  if (src1_needs_grads) {
  ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5425,7 +5805,7 @@ static void ggml_compute_backward(
  // tensor = src0 * 1 + src1 * 0
  if (src0_needs_grads) {
  // dsrc0 = dtensor * 1
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
  }
  if (src1_needs_grads) {
  // dsrc1 = dtensor * 0 -> noop
@@ -5436,7 +5816,9 @@ static void ggml_compute_backward(
  if (src0_needs_grads) {
  GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
  GGML_ASSERT(ggml_is_contiguous(grad));
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+ ggml_add_or_set(ctx, cgraph, isrc0,
+ ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
  }
  } break;
  case GGML_OP_RESHAPE: {
@@ -5516,7 +5898,13 @@ static void ggml_compute_backward(
  } break;
  case GGML_OP_SOFT_MAX: {
  if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
+
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
  }
  GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
  } break;
@@ -5528,6 +5916,7 @@ static void ggml_compute_backward(
  //const int n_ctx = ((int32_t *) tensor->op_params)[3];
  const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ int sections[4] = {0, 0, 0, 0};

  memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@@ -5535,10 +5924,14 @@ static void ggml_compute_backward(
  memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
  memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
-
- ggml_add_or_set(ctx, cgraph, isrc0,
- ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
- freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+ memcpy(&sections, tensor->op_params + 11, sizeof(sections));
+
+ struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
+ ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
+ ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
  }
  GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
  } break;
@@ -5552,7 +5945,7 @@ static void ggml_compute_backward(
  const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
  const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;

- ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
  }
  } break;
  case GGML_OP_POOL_2D: {
@@ -5595,7 +5988,7 @@ static void ggml_compute_backward(
  } break;
  case GGML_UNARY_OP_SILU: {
  if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
  }
  } break;
  case GGML_UNARY_OP_EXP: {
@@ -5612,7 +6005,7 @@ static void ggml_compute_backward(
  } break;
  case GGML_OP_CROSS_ENTROPY_LOSS: {
  if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
  }
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
  } break;
@@ -5631,19 +6024,32 @@ static void ggml_compute_backward(
  GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
  }

- static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
  // check if already visited
- if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
- return;
+ size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+ GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+ if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+ // This is the first time we see this node in the current graph.
+ cgraph->visited_hash_set.keys[node_hash_pos] = node;
+ ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+ cgraph->use_counts[node_hash_pos] = 0;
+ } else {
+ // already visited
+ return node_hash_pos;
  }

  for (int i = 0; i < GGML_MAX_SRC; ++i) {
  const int k =
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
  (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
- /* unknown order, just fall back to using i*/ i;
- if (node->src[k]) {
- ggml_visit_parents(cgraph, node->src[k]);
+ /* unknown order, just fall back to using i */ i;
+
+ struct ggml_tensor * src = node->src[k];
+ if (src) {
+ size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+ // Update the use count for this operand.
+ cgraph->use_counts[src_hash_pos]++;
  }
  }

@@ -5667,6 +6073,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
  cgraph->nodes[cgraph->n_nodes] = node;
  cgraph->n_nodes++;
  }
+
+ return node_hash_pos;
  }

  static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -5693,10 +6101,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
  }

  void ggml_build_backward_expand(
- struct ggml_context * ctx_static,
- struct ggml_context * ctx_compute,
- struct ggml_cgraph * cgraph,
- bool accumulate) {
+ struct ggml_context * ctx,
+ struct ggml_cgraph * cgraph,
+ struct ggml_tensor ** grad_accs) {
  GGML_ASSERT(cgraph->n_nodes > 0);
  GGML_ASSERT(cgraph->grads);
  GGML_ASSERT(cgraph->grad_accs);
@@ -5769,21 +6176,24 @@ void ggml_build_backward_expand(
  GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
  node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);

- const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
- GGML_ASSERT(igrad != GGML_HASHSET_FULL);
- GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
- if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
- cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
- cgraph->grads[igrad] = cgraph->grad_accs[igrad];
- ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+ const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+ GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+ GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+ if (grad_accs && grad_accs[i]) {
+ cgraph->grad_accs[ihash] = grad_accs[i];
+ cgraph->grads[ihash] = cgraph->grad_accs[ihash];
+ } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+ // loss tensors always need a gradient accumulator
+ cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+ cgraph->grads[ihash] = cgraph->grad_accs[ihash];
  }
- grads_needed[igrad] = true;
+ grads_needed[ihash] = true;
  }

  for (int i = n_nodes_f - 1; i >= 0; --i) {
  // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
  // use allocator to automatically make inplace operations
- ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+ ggml_compute_backward(ctx, cgraph, i, grads_needed);
  }

  free(grads_needed);
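
The ggml_build_backward_expand hunk above is an API break: the separate static/compute contexts and the accumulate flag are gone, and callers instead pass a single context plus an optional array of pre-allocated gradient accumulators (one slot per forward node, or NULL to let loss tensors allocate their own). A hedged migration sketch under that reading of the hunk:

    /* before (1.3.1 vintage):
       ggml_build_backward_expand(ctx_static, ctx_compute, gb, false);  // accumulate = false */

    /* after (1.3.3): one context; NULL grad_accs reproduces the non-accumulating default */
    ggml_build_backward_expand(ctx, gb, /*grad_accs =*/ NULL);
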
@@ -5802,6 +6212,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
  incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+ incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
  if (grads) {
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -5831,11 +6242,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz

  void * p = cgraph + 1;

- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
- struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;

  ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

@@ -5850,6 +6262,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
  /*.grads =*/ grads_ptr,
  /*.grad_accs =*/ grad_accs_ptr,
  /*.leafs =*/ leafs_ptr,
+ /*.use_counts =*/ use_counts_ptr,
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
  /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
  };
@@ -5876,7 +6289,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
  /*.grads =*/ NULL, // gradients would need visited_hash_set
  /*.grad_accs =*/ NULL,
  /*.leafs =*/ NULL,
- /*.visited_hash_set =*/ { 0, NULL, NULL },
+ /*.use_counts =*/ cgraph0->use_counts,
+ /*.visited_hash_set =*/ cgraph0->visited_hash_set,
  /*.order =*/ cgraph0->order,
  };

@@ -5903,7 +6317,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
  // copy all hashset keys (tensors) that are in use
  if (ggml_bitset_get(src->visited_hash_set.used, i)) {
- ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+ size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+ dst->use_counts[new_hash_pos] = src->use_counts[i];
  }
  }

@@ -5929,8 +6344,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
  }
  }

- struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
- struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+ struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+ struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
  ggml_graph_cpy(cgraph, result);
  return result;
  }
@@ -5949,6 +6364,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
  }

  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+ if (!cgraph) {
+ return;
+ }
  GGML_ASSERT(cgraph->grads != NULL);

  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6258,8 +6676,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
  tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
  }

- void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
- GGML_UNUSED(ctx); // TODO: remove this parameter
+ void ggml_set_param(struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->op == GGML_OP_NONE);
  tensor->flags |= GGML_TENSOR_FLAG_PARAM;
  }

@@ -6383,1288 +6801,6 @@ size_t ggml_quantize_chunk(

  ////////////////////////////////////////////////////////////////////////////////

- struct gguf_str {
- uint64_t n; // GGUFv2
- char * data;
- };
-
- static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
- [GGUF_TYPE_UINT8] = sizeof(uint8_t),
- [GGUF_TYPE_INT8] = sizeof(int8_t),
- [GGUF_TYPE_UINT16] = sizeof(uint16_t),
- [GGUF_TYPE_INT16] = sizeof(int16_t),
- [GGUF_TYPE_UINT32] = sizeof(uint32_t),
- [GGUF_TYPE_INT32] = sizeof(int32_t),
- [GGUF_TYPE_FLOAT32] = sizeof(float),
- [GGUF_TYPE_BOOL] = sizeof(bool),
- [GGUF_TYPE_STRING] = sizeof(struct gguf_str),
- [GGUF_TYPE_UINT64] = sizeof(uint64_t),
- [GGUF_TYPE_INT64] = sizeof(int64_t),
- [GGUF_TYPE_FLOAT64] = sizeof(double),
- [GGUF_TYPE_ARRAY] = 0, // undefined
- };
- static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
- static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
- [GGUF_TYPE_UINT8] = "u8",
- [GGUF_TYPE_INT8] = "i8",
- [GGUF_TYPE_UINT16] = "u16",
- [GGUF_TYPE_INT16] = "i16",
- [GGUF_TYPE_UINT32] = "u32",
- [GGUF_TYPE_INT32] = "i32",
- [GGUF_TYPE_FLOAT32] = "f32",
- [GGUF_TYPE_BOOL] = "bool",
- [GGUF_TYPE_STRING] = "str",
- [GGUF_TYPE_ARRAY] = "arr",
- [GGUF_TYPE_UINT64] = "u64",
- [GGUF_TYPE_INT64] = "i64",
- [GGUF_TYPE_FLOAT64] = "f64",
- };
- static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
- union gguf_value {
- uint8_t uint8;
- int8_t int8;
- uint16_t uint16;
- int16_t int16;
- uint32_t uint32;
- int32_t int32;
- float float32;
- uint64_t uint64;
- int64_t int64;
- double float64;
- bool bool_;
-
- struct gguf_str str;
-
- struct {
- enum gguf_type type;
-
- uint64_t n; // GGUFv2
- void * data;
- } arr;
- };
-
- struct gguf_kv {
- struct gguf_str key;
-
- enum gguf_type type;
- union gguf_value value;
- };
-
- struct gguf_header {
- char magic[4];
-
- uint32_t version;
- uint64_t n_tensors; // GGUFv2
- uint64_t n_kv; // GGUFv2
- };
-
- struct gguf_tensor_info {
- struct gguf_str name;
-
- uint32_t n_dims;
- uint64_t ne[GGML_MAX_DIMS];
-
- enum ggml_type type;
-
- uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
-
- // for writing API
- const void * data;
- size_t size;
- };
-
- struct gguf_context {
- struct gguf_header header;
-
- struct gguf_kv * kv;
- struct gguf_tensor_info * infos;
-
- size_t alignment;
- size_t offset; // offset of `data` from beginning of file
- size_t size; // size of `data` in bytes
-
- //uint8_t * padding;
- void * data;
- };
-
- static size_t gguf_type_size(enum gguf_type type) {
- GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
- return GGUF_TYPE_SIZE[type];
- }
-
- static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
- if (info->n_dims > GGML_MAX_DIMS) {
- fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
- return false;
- }
-
- if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
- fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
- return false;
- }
-
- if (strlen(info->name.data) >= GGML_MAX_NAME) {
- fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
- return false;
- }
-
- for (uint32_t i = 0; i < info->n_dims; ++i) {
- if (info->ne[i] <= 0) {
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
- return false;
- }
- }
-
- // prevent overflow for total number of elements
- if (INT64_MAX/info->ne[1] <= info->ne[0]) {
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
- return false;
- }
-
- if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
- return false;
- }
-
- if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
- return false;
- }
-
- return true;
- }
-
- static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
- const size_t n = fread(dst, 1, size, file);
- *offset += n;
- return n == size;
- }
-
- static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
- p->n = 0;
- p->data = NULL;
-
- bool ok = true;
-
- ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
-
- // early exit if string length is invalid, prevents from integer overflow
- if (p->n == SIZE_MAX) {
- fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
- return false;
- }
-
- p->data = calloc(p->n + 1, 1);
- if (!p->data) {
- fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
- return false;
- }
-
- ok = ok && gguf_fread_el(file, p->data, p->n, offset);
-
- return ok;
- }
-
- static void gguf_free_kv(struct gguf_kv * kv) {
- if (kv->key.data) {
- GGML_FREE(kv->key.data);
- }
-
- if (kv->type == GGUF_TYPE_STRING) {
- if (kv->value.str.data) {
- GGML_FREE(kv->value.str.data);
- }
- }
-
- if (kv->type == GGUF_TYPE_ARRAY) {
- if (kv->value.arr.data) {
- if (kv->value.arr.type == GGUF_TYPE_STRING) {
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
- struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
- if (str->data) {
- GGML_FREE(str->data);
- }
- }
- }
- GGML_FREE(kv->value.arr.data);
- }
- }
- }
-
- struct gguf_context * gguf_init_empty(void) {
- struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
- if (!ctx) {
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
- return NULL;
- }
-
- memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
- ctx->header.version = GGUF_VERSION;
- ctx->header.n_tensors = 0;
- ctx->header.n_kv = 0;
-
- ctx->kv = NULL;
- ctx->infos = NULL;
-
- ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
- ctx->offset = 0;
- ctx->size = 0;
-
- ctx->data = NULL;
-
- return ctx;
- }
-
- struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
- FILE * file = ggml_fopen(fname, "rb");
- if (!file) {
- fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
- return NULL;
- }
-
- // offset from start of file
- size_t offset = 0;
-
- char magic[4];
-
- // check the magic before making allocations
- {
- gguf_fread_el(file, &magic, sizeof(magic), &offset);
-
- for (uint32_t i = 0; i < sizeof(magic); i++) {
- if (magic[i] != GGUF_MAGIC[i]) {
- fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
- fclose(file);
- return NULL;
- }
- }
- }
-
- bool ok = true;
-
- struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
- if (!ctx) {
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
- fclose(file);
- return NULL;
- }
-
- // read the header
- {
- strncpy(ctx->header.magic, magic, 4);
-
- ctx->kv = NULL;
- ctx->infos = NULL;
- ctx->data = NULL;
-
- ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
- ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
- ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
-
- if (ctx->header.version == 1) {
- fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- // sanity-checks to prevent from integer/buffer overflows
-
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
- ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_kv));
-
- if (!ok) {
- fprintf(stderr, "%s: failed to read header\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
- }
-
- // read the kv pairs
- {
- const uint64_t n_kv = ctx->header.n_kv;
-
- ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
- if (!ctx->kv) {
- fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- for (uint64_t i = 0; i < n_kv; ++i) {
- struct gguf_kv * kv = &ctx->kv[i];
-
- //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
-
- ok = ok && gguf_fread_str(file, &kv->key, &offset);
- ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
-
- //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
-
- switch (kv->type) {
- case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break;
- case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break;
- case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break;
- case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break;
- case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
- case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
- case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
- case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
- case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
- case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
- case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
- case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break;
- case GGUF_TYPE_ARRAY:
- {
- ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
- ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
-
- switch (kv->value.arr.type) {
- case GGUF_TYPE_UINT8:
- case GGUF_TYPE_INT8:
- case GGUF_TYPE_UINT16:
- case GGUF_TYPE_INT16:
- case GGUF_TYPE_UINT32:
- case GGUF_TYPE_INT32:
- case GGUF_TYPE_FLOAT32:
- case GGUF_TYPE_UINT64:
- case GGUF_TYPE_INT64:
- case GGUF_TYPE_FLOAT64:
- case GGUF_TYPE_BOOL:
- {
- // prevent from integer overflow in the malloc below
- if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
- if (!kv->value.arr.data) {
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
- } break;
- case GGUF_TYPE_STRING:
- {
- // prevent from integer overflow in the malloc below
- if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
- if (!kv->value.arr.data) {
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
- ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
- }
- } break;
- case GGUF_TYPE_ARRAY:
- default:
- {
- fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
- ok = false;
- } break;
- }
- } break;
- default:
- {
- fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
- ok = false;
- } break;
- }
-
- if (!ok) {
- break;
- }
- }
-
- if (!ok) {
- fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
- }
-
- // read the tensor infos
- if (ctx->header.n_tensors > 0) {
- ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
- if (!ctx->infos) {
- fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
- struct gguf_tensor_info * info = &ctx->infos[i];
-
- for (int j = 0; j < GGML_MAX_DIMS; ++j) {
- info->ne[j] = 1;
- }
-
- ok = ok && gguf_fread_str(file, &info->name, &offset);
- ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
-
- ok = ok && (info->n_dims <= GGML_MAX_DIMS);
-
- for (uint32_t j = 0; j < info->n_dims; ++j) {
- ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
- }
-
- ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
- ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
-
- ok = ok && gguf_tensor_info_sanitize(info);
-
- // make sure there is no duplicated tensor names
- for (uint64_t j = 0; j < i && ok; ++j) {
- if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
- fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
- ok = false;
- }
- }
-
- if (!ok) {
- fprintf(stderr, "%s: failed to read tensor info\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
- }
- }
-
- ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-
- int alignment_idx = gguf_find_key(ctx, "general.alignment");
- if (alignment_idx != -1) {
- ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
- }
-
- // we require the data section to be aligned, so take into account any padding
- {
- const size_t offset_pad = offset % ctx->alignment;
-
- if (offset_pad != 0) {
- offset += ctx->alignment - offset_pad;
- fseek(file, offset, SEEK_SET);
- }
- }
-
- // store the current file offset - this is where the data section starts
- ctx->offset = offset;
-
- // compute the total size of the data section, taking into account the alignment
- {
- ctx->size = 0;
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
- struct gguf_tensor_info * info = &ctx->infos[i];
-
- const int64_t ne =
- (int64_t) info->ne[0] *
- (int64_t) info->ne[1] *
- (int64_t) info->ne[2] *
- (int64_t) info->ne[3];
-
- if (ggml_blck_size(info->type) == 0 ) {
- // this tensor type support have been removed:
- fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
- __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- if (ne % ggml_blck_size(info->type) != 0) {
- fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
- __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- const size_t size_cur = ggml_row_size(info->type, ne);
-
- ctx->size += GGML_PAD(size_cur, ctx->alignment);
- }
- }
-
- // load the tensor data only if requested
- if (params.ctx != NULL) {
- // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
- // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
- // the ggml_tensor structs to the appropriate locations in the binary blob
-
- // compute the exact size needed for the new ggml_context
- const size_t mem_size =
- params.no_alloc ?
- (ctx->header.n_tensors )*ggml_tensor_overhead() :
- (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
-
- struct ggml_init_params pdata = {
- .mem_size = mem_size,
- .mem_buffer = NULL,
- .no_alloc = params.no_alloc,
- };
-
- *params.ctx = ggml_init(pdata);
- if (*params.ctx == NULL) {
- fprintf(stderr, "%s: failed to initialize context\n", __func__);
- fclose(file);
- gguf_free(ctx);
- return NULL;
- }
-
- struct ggml_context * ctx_data = *params.ctx;
-
- struct ggml_tensor * data = NULL;
-
- if (!params.no_alloc) {
- data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
-
- ok = ok && data != NULL;
-
- // read the binary blob with the tensor data
- ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
-
- if (!ok) {
- fprintf(stderr, "%s: failed to read tensor data\n", __func__);
- fclose(file);
- ggml_free(ctx_data);
- gguf_free(ctx);
- return NULL;
- }
-
- ctx->data = data->data;
- }
-
- ggml_set_no_alloc(ctx_data, true);
-
- // create the tensors
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
- const int64_t ne[GGML_MAX_DIMS] = {
- ctx->infos[i].ne[0],
- ctx->infos[i].ne[1],
- ctx->infos[i].ne[2],
- ctx->infos[i].ne[3],
- };
-
- struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
-
- ok = ok && cur != NULL;
-
- if (!ok) {
- break;
- }
-
- ggml_set_name(cur, ctx->infos[i].name.data);
-
- // point the data member to the appropriate location in the binary blob using the tensor infos
- if (!params.no_alloc) {
- //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
- cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
- }
- }
-
- if (!ok) {
- fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
- fclose(file);
- ggml_free(ctx_data);
- gguf_free(ctx);
- return NULL;
- }
-
- ggml_set_no_alloc(ctx_data, params.no_alloc);
- }
-
- fclose(file);
-
- return ctx;
- }
-
- void gguf_free(struct gguf_context * ctx) {
- if (ctx == NULL) {
- return;
- }
-
- if (ctx->kv) {
- // free string memory - not great..
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
- gguf_free_kv(&ctx->kv[i]);
- }
-
- GGML_FREE(ctx->kv);
- }
-
- if (ctx->infos) {
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
- struct gguf_tensor_info * info = &ctx->infos[i];
-
- if (info->name.data) {
- GGML_FREE(info->name.data);
- }
- }
-
- GGML_FREE(ctx->infos);
- }
-
- GGML_FREE(ctx);
- }
-
- const char * gguf_type_name(enum gguf_type type) {
- return GGUF_TYPE_NAME[type];
- }
-
- int gguf_get_version(const struct gguf_context * ctx) {
- return ctx->header.version;
- }
-
- size_t gguf_get_alignment(const struct gguf_context * ctx) {
- return ctx->alignment;
- }
-
- size_t gguf_get_data_offset(const struct gguf_context * ctx) {
- return ctx->offset;
- }
-
- void * gguf_get_data(const struct gguf_context * ctx) {
- return ctx->data;
- }
-
- int gguf_get_n_kv(const struct gguf_context * ctx) {
- return ctx->header.n_kv;
- }
-
- int gguf_find_key(const struct gguf_context * ctx, const char * key) {
- // return -1 if key not found
- int keyfound = -1;
-
- const int n_kv = gguf_get_n_kv(ctx);
-
- for (int i = 0; i < n_kv; ++i) {
- if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
- keyfound = i;
- break;
- }
- }
-
- return keyfound;
- }
-
- const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- return ctx->kv[key_id].key.data;
- }
-
- enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- return ctx->kv[key_id].type;
- }
-
- enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
- return ctx->kv[key_id].value.arr.type;
- }
-
- const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
- return ctx->kv[key_id].value.arr.data;
- }
-
- const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
- struct gguf_kv * kv = &ctx->kv[key_id];
- struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
- return str->data;
- }
-
- int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
- return ctx->kv[key_id].value.arr.n;
- }
-
- uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
- return ctx->kv[key_id].value.uint8;
- }
-
- int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
- return ctx->kv[key_id].value.int8;
- }
-
- uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
- return ctx->kv[key_id].value.uint16;
- }
-
- int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
- return ctx->kv[key_id].value.int16;
- }
-
- uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
- return ctx->kv[key_id].value.uint32;
- }
-
- int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
- return ctx->kv[key_id].value.int32;
- }
-
- float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
- return ctx->kv[key_id].value.float32;
- }
-
- uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
- return ctx->kv[key_id].value.uint64;
- }
-
- int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
- return ctx->kv[key_id].value.int64;
- }
-
- double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
- return ctx->kv[key_id].value.float64;
- }
-
- bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
- return ctx->kv[key_id].value.bool_;
- }
-
- const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
- return ctx->kv[key_id].value.str.data;
- }
-
- const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
- GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
- return &ctx->kv[key_id].value;
- }
-
- int gguf_get_n_tensors(const struct gguf_context * ctx) {
- return ctx->header.n_tensors;
- }
-
- int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
- // return -1 if tensor not found
- int tensorfound = -1;
-
- const int n_tensors = gguf_get_n_tensors(ctx);
-
- for (int i = 0; i < n_tensors; ++i) {
- if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
- tensorfound = i;
- break;
- }
- }
-
- return tensorfound;
- }
-
- size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
- return ctx->infos[i].offset;
- }
-
- char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
- return ctx->infos[i].name.data;
- }
-
- enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
- return ctx->infos[i].type;
- }
-
- // returns the index
- static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
- const int idx = gguf_find_key(ctx, key);
- if (idx >= 0) {
- return idx;
- }
-
- const int n_kv = gguf_get_n_kv(ctx);
-
- ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
- ctx->kv[n_kv].key.n = strlen(key);
- ctx->kv[n_kv].key.data = strdup(key);
- ctx->header.n_kv++;
-
- return n_kv;
- }
-
- void gguf_remove_key(struct gguf_context * ctx, const char * key) {
- const int idx = gguf_find_key(ctx, key);
- if (idx >= 0) {
- const int n_kv = gguf_get_n_kv(ctx);
- gguf_free_kv(&ctx->kv[idx]);
- for (int i = idx; i < n_kv-1; ++i) {
- ctx->kv[i] = ctx->kv[i+1];
- }
- ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
- ctx->header.n_kv--;
- }
- }
-
- void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
- const int idx = gguf_get_or_add_key(ctx, key);
-
- ctx->kv[idx].type = GGUF_TYPE_UINT8;
- ctx->kv[idx].value.uint8 = val;
7255
- }
7256
-
7257
- void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
7258
- const int idx = gguf_get_or_add_key(ctx, key);
7259
-
7260
- ctx->kv[idx].type = GGUF_TYPE_INT8;
7261
- ctx->kv[idx].value.int8 = val;
7262
- }
7263
-
7264
- void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
7265
- const int idx = gguf_get_or_add_key(ctx, key);
7266
-
7267
- ctx->kv[idx].type = GGUF_TYPE_UINT16;
7268
- ctx->kv[idx].value.uint16 = val;
7269
- }
7270
-
7271
- void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
7272
- const int idx = gguf_get_or_add_key(ctx, key);
7273
-
7274
- ctx->kv[idx].type = GGUF_TYPE_INT16;
7275
- ctx->kv[idx].value.int16 = val;
7276
- }
7277
-
7278
- void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
7279
- const int idx = gguf_get_or_add_key(ctx, key);
7280
-
7281
- ctx->kv[idx].type = GGUF_TYPE_UINT32;
7282
- ctx->kv[idx].value.uint32 = val;
7283
- }
7284
-
7285
- void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
7286
- const int idx = gguf_get_or_add_key(ctx, key);
7287
-
7288
- ctx->kv[idx].type = GGUF_TYPE_INT32;
7289
- ctx->kv[idx].value.int32 = val;
7290
- }
7291
-
7292
- void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
7293
- const int idx = gguf_get_or_add_key(ctx, key);
7294
-
7295
- ctx->kv[idx].type = GGUF_TYPE_FLOAT32;
7296
- ctx->kv[idx].value.float32 = val;
7297
- }
7298
-
7299
- void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
7300
- const int idx = gguf_get_or_add_key(ctx, key);
7301
-
7302
- ctx->kv[idx].type = GGUF_TYPE_UINT64;
7303
- ctx->kv[idx].value.uint64 = val;
7304
- }
7305
-
7306
- void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
7307
- const int idx = gguf_get_or_add_key(ctx, key);
7308
-
7309
- ctx->kv[idx].type = GGUF_TYPE_INT64;
7310
- ctx->kv[idx].value.int64 = val;
7311
- }
7312
-
7313
- void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
7314
- const int idx = gguf_get_or_add_key(ctx, key);
7315
-
7316
- ctx->kv[idx].type = GGUF_TYPE_FLOAT64;
7317
- ctx->kv[idx].value.float64 = val;
7318
- }
7319
-
7320
- void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
7321
- const int idx = gguf_get_or_add_key(ctx, key);
7322
-
7323
- ctx->kv[idx].type = GGUF_TYPE_BOOL;
7324
- ctx->kv[idx].value.bool_ = val;
7325
- }
7326
-
7327
- void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
7328
- const int idx = gguf_get_or_add_key(ctx, key);
7329
-
7330
- ctx->kv[idx].type = GGUF_TYPE_STRING;
7331
- ctx->kv[idx].value.str.n = strlen(val);
7332
- ctx->kv[idx].value.str.data = strdup(val);
7333
- }
7334
-
7335
- void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
7336
- const int idx = gguf_get_or_add_key(ctx, key);
7337
-
7338
- ctx->kv[idx].type = GGUF_TYPE_ARRAY;
7339
- ctx->kv[idx].value.arr.type = type;
7340
- ctx->kv[idx].value.arr.n = n;
7341
- ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
7342
- memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
7343
- }
7344
-
7345
- void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
7346
- const int idx = gguf_get_or_add_key(ctx, key);
7347
-
7348
- ctx->kv[idx].type = GGUF_TYPE_ARRAY;
7349
- ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
7350
- ctx->kv[idx].value.arr.n = n;
7351
- ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
7352
- for (int i = 0; i < n; i++) {
7353
- struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
7354
- str->n = strlen(data[i]);
7355
- str->data = strdup(data[i]);
7356
- }
7357
- }
7358
-
7359
- // set or add KV pairs from another context
7360
- void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
7361
- for (uint32_t i = 0; i < src->header.n_kv; i++) {
7362
- switch (src->kv[i].type) {
7363
- case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break;
7364
- case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break;
7365
- case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break;
7366
- case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break;
7367
- case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
7368
- case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
7369
- case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
7370
- case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
7371
- case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
7372
- case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
7373
- case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
7374
- case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
7375
- case GGUF_TYPE_ARRAY:
7376
- {
7377
- if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
7378
- const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
7379
- for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
7380
- data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
7381
- }
7382
- gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
7383
- GGML_FREE((void *)data);
7384
- } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
7385
- GGML_ABORT("nested arrays not supported");
7386
- } else {
7387
- gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
7388
- }
7389
- } break;
7390
- default: GGML_ABORT("invalid type");
7391
- }
7392
- }
7393
- }
7394
-
7395
- void gguf_add_tensor(
7396
- struct gguf_context * ctx,
7397
- const struct ggml_tensor * tensor) {
7398
- GGML_ASSERT(tensor);
7399
- if (gguf_find_tensor(ctx, tensor->name) != -1) {
7400
- GGML_ABORT("duplicated tensor name");
7401
- }
7402
-
7403
- const int idx = ctx->header.n_tensors;
7404
- ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
7405
-
7406
- ctx->infos[idx].name.n = strlen(tensor->name);
7407
- ctx->infos[idx].name.data = strdup(tensor->name);
7408
-
7409
- for (int i = 0; i < GGML_MAX_DIMS; ++i) {
7410
- ctx->infos[idx].ne[i] = 1;
7411
- }
7412
-
7413
- ctx->infos[idx].n_dims = ggml_n_dims(tensor);
7414
- for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
7415
- ctx->infos[idx].ne[i] = tensor->ne[i];
7416
- }
7417
-
7418
- ctx->infos[idx].type = tensor->type;
7419
- ctx->infos[idx].offset = 0;
7420
- ctx->infos[idx].data = tensor->data;
7421
- ctx->infos[idx].size = ggml_nbytes(tensor);
7422
-
7423
- if (ctx->header.n_tensors > 0) {
7424
- ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
7425
- }
7426
-
7427
- ctx->header.n_tensors++;
7428
- }
7429
-
7430
- void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
7431
- const int idx = gguf_find_tensor(ctx, name);
7432
- if (idx < 0) {
7433
- GGML_ABORT("tensor not found");
7434
- }
7435
-
7436
- ctx->infos[idx].type = type;
7437
- }
7438
-
7439
- void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
7440
- const int idx = gguf_find_tensor(ctx, name);
7441
- if (idx < 0) {
7442
- GGML_ABORT("tensor not found");
7443
- }
7444
-
7445
- ctx->infos[idx].data = data;
7446
- ctx->infos[idx].size = size;
7447
-
7448
- // update offsets
7449
- for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
7450
- ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
7451
- }
7452
- }
7453
-
7454
- //static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
7455
- // fwrite(&val->n, sizeof(val->n), 1, file);
7456
- // fwrite(val->data, sizeof(char), val->n, file);
7457
- //}
7458
- //
7459
- //static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
7460
- // fwrite(val, sizeof(char), size, file);
7461
- //}
7462
-
7463
- struct gguf_buf {
7464
- void * data;
7465
- size_t size;
7466
- size_t offset;
7467
- };
7468
-
7469
- static struct gguf_buf gguf_buf_init(size_t size) {
7470
- struct gguf_buf buf = {
7471
- /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
7472
- /*buf.size =*/ size,
7473
- /*buf.offset =*/ 0,
7474
- };
7475
-
7476
- return buf;
7477
- }
7478
-
7479
- static void gguf_buf_free(struct gguf_buf buf) {
7480
- if (buf.data) {
7481
- GGML_FREE(buf.data);
7482
- }
7483
- }
7484
-
7485
- static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
7486
- if (buf->offset + size > buf->size) {
7487
- buf->size = 1.5*(buf->offset + size);
7488
- if (buf->data) {
7489
- buf->data = realloc(buf->data, buf->size);
7490
- }
7491
- }
7492
- }
7493
-
7494
- static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
7495
- gguf_buf_grow(buf, sizeof(val->n) + val->n);
7496
-
7497
- if (buf->data) {
7498
- memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
7499
- }
7500
- buf->offset += sizeof(val->n);
7501
-
7502
- if (buf->data) {
7503
- memcpy((char *) buf->data + buf->offset, val->data, val->n);
7504
- }
7505
- buf->offset += val->n;
7506
- }
7507
-
7508
- static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
7509
- gguf_buf_grow(buf, el_size);
7510
-
7511
- if (buf->data) {
7512
- memcpy((char *) buf->data + buf->offset, val, el_size);
7513
- }
7514
- buf->offset += el_size;
7515
- }
7516
-
7517
- static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
7518
- // write header
7519
- gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
7520
- gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
7521
- gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
7522
- gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv));
7523
-
7524
- // write key-value pairs
7525
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
7526
- struct gguf_kv * kv = &ctx->kv[i];
7527
-
7528
- gguf_bwrite_str(buf, &kv->key);
7529
- gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
7530
-
7531
- switch (kv->type) {
7532
- case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break;
7533
- case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break;
7534
- case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break;
7535
- case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break;
7536
- case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
7537
- case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
7538
- case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
7539
- case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
7540
- case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
7541
- case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
7542
- case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
7543
- case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break;
7544
- case GGUF_TYPE_ARRAY:
7545
- {
7546
- gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
7547
- gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) );
7548
-
7549
- switch (kv->value.arr.type) {
7550
- case GGUF_TYPE_UINT8:
7551
- case GGUF_TYPE_INT8:
7552
- case GGUF_TYPE_UINT16:
7553
- case GGUF_TYPE_INT16:
7554
- case GGUF_TYPE_UINT32:
7555
- case GGUF_TYPE_INT32:
7556
- case GGUF_TYPE_FLOAT32:
7557
- case GGUF_TYPE_UINT64:
7558
- case GGUF_TYPE_INT64:
7559
- case GGUF_TYPE_FLOAT64:
7560
- case GGUF_TYPE_BOOL:
7561
- {
7562
- gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
7563
- } break;
7564
- case GGUF_TYPE_STRING:
7565
- {
7566
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
7567
- gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
7568
- }
7569
- } break;
7570
- case GGUF_TYPE_ARRAY:
7571
- default: GGML_ABORT("invalid type");
7572
- }
7573
- } break;
7574
- default: GGML_ABORT("invalid type");
7575
- }
7576
- }
7577
-
7578
- // write tensor infos
7579
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7580
- struct gguf_tensor_info * info = &ctx->infos[i];
7581
-
7582
- gguf_bwrite_str(buf, &info->name);
7583
- gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
7584
- for (uint32_t j = 0; j < info->n_dims; ++j) {
7585
- gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
7586
- }
7587
- gguf_bwrite_el(buf, &info->type, sizeof(info->type));
7588
- gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
7589
- }
7590
-
7591
- // we require the data section to be aligned, so take into account any padding
7592
- {
7593
- const size_t offset = buf->offset;
7594
- const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
7595
-
7596
- if (offset_pad != offset) {
7597
- uint8_t pad = 0;
7598
- for (size_t i = 0; i < offset_pad - offset; ++i) {
7599
- gguf_bwrite_el(buf, &pad, sizeof(pad));
7600
- }
7601
- }
7602
- }
7603
-
7604
- if (only_meta) {
7605
- return;
7606
- }
7607
-
7608
- size_t offset = 0;
7609
-
7610
- // write tensor data
7611
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7612
- struct gguf_tensor_info * info = &ctx->infos[i];
7613
-
7614
- const size_t size = info->size;
7615
- const size_t size_pad = GGML_PAD(size, ctx->alignment);
7616
-
7617
- gguf_bwrite_el(buf, info->data, size);
7618
-
7619
- if (size_pad != size) {
7620
- uint8_t pad = 0;
7621
- for (size_t j = 0; j < size_pad - size; ++j) {
7622
- gguf_bwrite_el(buf, &pad, sizeof(pad));
7623
- }
7624
- }
7625
-
7626
- GGML_ASSERT(offset == info->offset);
7627
-
7628
- offset += size_pad;
7629
- }
7630
- }
7631
-
7632
- void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
7633
- FILE * file = ggml_fopen(fname, "wb");
7634
- if (!file) {
7635
- GGML_ABORT("failed to open file for writing");
7636
- }
7637
-
7638
- struct gguf_buf buf = gguf_buf_init(16*1024);
7639
-
7640
- gguf_write_to_buf(ctx, &buf, only_meta);
7641
-
7642
- fwrite(buf.data, 1, buf.offset, file);
7643
-
7644
- gguf_buf_free(buf);
7645
-
7646
- fclose(file);
7647
- }
7648
-
7649
- size_t gguf_get_meta_size(const struct gguf_context * ctx) {
7650
- // no allocs - only compute size
7651
- struct gguf_buf buf = gguf_buf_init(0);
7652
-
7653
- gguf_write_to_buf(ctx, &buf, true);
7654
-
7655
- return buf.offset;
7656
- }
7657
-
7658
- void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
7659
- struct gguf_buf buf = gguf_buf_init(16*1024);
7660
-
7661
- gguf_write_to_buf(ctx, &buf, true);
7662
-
7663
- memcpy(data, buf.data, buf.offset);
7664
-
7665
- gguf_buf_free(buf);
7666
- }
7667
-
7668
6804
  void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
7669
6805
  g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
7670
6806
  g_logger_state.log_callback_user_data = user_data;