whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797):
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -4,6 +4,7 @@
4
4
  #include "ggml-backend.h"
5
5
  #include "ggml-impl.h"
6
6
  #include "ggml-threading.h"
7
+ #include "ggml-cpu.h"
7
8
  #include "ggml.h"
8
9
 
9
10
  // FIXME: required here for quantization functions
@@ -63,12 +64,17 @@
63
64
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
64
65
  float ggml_table_f32_f16[1 << 16];
65
66
 
66
- #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
67
- (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
67
+ #if defined(__linux__) || \
68
+ defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
69
+ (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
70
+
68
71
  #include <unistd.h>
69
72
  #include <sys/types.h>
70
73
  #include <sys/stat.h>
71
74
  #include <sys/wait.h>
75
+ #if defined(__linux__)
76
+ #include <sys/prctl.h>
77
+ #endif
72
78
 
73
79
  #if defined(__ANDROID__)
74
80
  #include <unwind.h>
@@ -128,10 +134,40 @@ static void ggml_print_backtrace_symbols(void) {
128
134
  #endif
129
135
 
130
136
  static void ggml_print_backtrace(void) {
131
- char attach[32];
132
- snprintf(attach, sizeof(attach), "attach %d", getpid());
133
- int pid = fork();
134
- if (pid == 0) {
137
+ const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
138
+ if (GGML_NO_BACKTRACE) {
139
+ return;
140
+ }
141
+ #if defined(__linux__)
142
+ FILE * f = fopen("/proc/self/status", "r");
143
+ size_t size = 0;
144
+ char * line = NULL;
145
+ ssize_t length = 0;
146
+ while ((length = getline(&line, &size, f)) > 0) {
147
+ if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
148
+ (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
149
+ // Already being debugged, and the breakpoint is the later abort()
150
+ free(line);
151
+ fclose(f);
152
+ return;
153
+ }
154
+ }
155
+ free(line);
156
+ fclose(f);
157
+ int lock[2] = { -1, -1 };
158
+ (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
159
+ #endif
160
+ const int parent_pid = getpid();
161
+ const int child_pid = fork();
162
+ if (child_pid < 0) { // error
163
+ return;
164
+ } else if (child_pid == 0) { // child
165
+ char attach[32];
166
+ snprintf(attach, sizeof(attach), "attach %d", parent_pid);
167
+ #if defined(__linux__)
168
+ close(lock[1]);
169
+ (void) !read(lock[0], lock, 1);
170
+ #endif
135
171
  // try gdb
136
172
  execlp("gdb", "gdb", "--batch",
137
173
  "-ex", "set style enabled on",
@@ -144,18 +180,18 @@ static void ggml_print_backtrace(void) {
144
180
  execlp("lldb", "lldb", "--batch",
145
181
  "-o", "bt",
146
182
  "-o", "quit",
147
- "-p", attach,
183
+ "-p", &attach[sizeof("attach ") - 1],
148
184
  (char *) NULL);
149
- exit(EXIT_FAILURE);
150
- } else {
151
- int wstatus;
152
- waitpid(pid, &wstatus, 0);
153
- if (WIFEXITED(wstatus)) {
154
- if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
155
- // gdb failed, fallback to backtrace_symbols
156
- ggml_print_backtrace_symbols();
157
- }
158
- }
185
+ // gdb failed, fallback to backtrace_symbols
186
+ ggml_print_backtrace_symbols();
187
+ _Exit(0);
188
+ } else { // parent
189
+ #if defined(__linux__)
190
+ prctl(PR_SET_PTRACER, child_pid);
191
+ close(lock[1]);
192
+ close(lock[0]);
193
+ #endif
194
+ waitpid(child_pid, NULL, 0);
159
195
  }
160
196
  }
161
197
  #else
@@ -236,7 +272,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
236
272
 
237
273
 
238
274
  void * ggml_aligned_malloc(size_t size) {
275
+ #if defined(__s390x__)
276
+ const int alignment = 256;
277
+ #else
239
278
  const int alignment = 64;
279
+ #endif
240
280
 
241
281
  #if defined(_MSC_VER) || defined(__MINGW32__)
242
282
  return _aligned_malloc(size, alignment);
@@ -374,58 +414,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
374
414
  }
375
415
  }
376
416
 
377
- // FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
378
- // currently, the ggml_cpu_has_* functions are entirely compile-time
379
417
  void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
380
- int64_t i = 0;
381
- #if defined(__F16C__)
382
- //if (ggml_cpu_has_f16c()) {
383
- for (; i + 7 < n; i += 8) {
384
- __m256 x_vec = _mm256_loadu_ps(x + i);
385
- __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
386
- _mm_storeu_si128((__m128i *)(y + i), y_vec);
387
- }
388
- for(; i + 3 < n; i += 4) {
389
- __m128 x_vec = _mm_loadu_ps(x + i);
390
- __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
391
- _mm_storel_epi64((__m128i *)(y + i), y_vec);
392
- }
393
- //}
394
- #endif
395
- for (; i < n; i++) {
418
+ int i = 0;
419
+ for (; i < n; ++i) {
396
420
  y[i] = GGML_FP32_TO_FP16(x[i]);
397
421
  }
398
422
  }
399
423
 
400
424
  void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
401
- int64_t i = 0;
402
- #if defined(__AVX512F__)
403
- //if (ggml_cpu_has_avx512()) {
404
- for (; i + 16 <= n; i += 16) {
405
- _mm512_storeu_ps(y + i,
406
- _mm512_castsi512_ps(
407
- _mm512_slli_epi32(
408
- _mm512_cvtepu16_epi32(
409
- _mm256_loadu_si256(
410
- (const __m256i *)(x + i))),
411
- 16)));
412
- }
413
- //}
414
- #endif
415
- #if defined(__AVX2__)
416
- //if (ggml_cpu_has_avx2()) {
417
- for (; i + 8 <= n; i += 8) {
418
- _mm256_storeu_ps(y + i,
419
- _mm256_castsi256_ps(
420
- _mm256_slli_epi32(
421
- _mm256_cvtepu16_epi32(
422
- _mm_loadu_si128(
423
- (const __m128i *)(x + i))),
424
- 16)));
425
- }
426
- //}
427
- #endif
428
- for (; i < n; i++) {
425
+ int i = 0;
426
+ for (; i < n; ++i) {
429
427
  y[i] = GGML_BF16_TO_FP32(x[i]);
430
428
  }
431
429
  }
@@ -557,9 +555,9 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
557
555
  #endif
558
556
 
559
557
  }
560
- static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
561
- static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
562
- static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
558
+ static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
559
+ static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
560
+ static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
563
561
 
564
562
  static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
565
563
  [GGML_TYPE_I8] = {
@@ -921,6 +919,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
921
919
  "RMS_NORM",
922
920
  "RMS_NORM_BACK",
923
921
  "GROUP_NORM",
922
+ "L2_NORM",
924
923
 
925
924
  "MUL_MAT",
926
925
  "MUL_MAT_ID",
@@ -947,6 +946,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
947
946
  "CONV_TRANSPOSE_1D",
948
947
  "IM2COL",
949
948
  "IM2COL_BACK",
949
+ "CONV_2D_DW",
950
950
  "CONV_TRANSPOSE_2D",
951
951
  "POOL_1D",
952
952
  "POOL_2D",
@@ -968,20 +968,17 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
968
968
  "GET_REL_POS",
969
969
  "ADD_REL_POS",
970
970
  "RWKV_WKV6",
971
+ "GATED_LINEAR_ATTN",
972
+ "RWKV_WKV7",
971
973
 
972
974
  "UNARY",
973
975
 
974
- "MAP_UNARY",
975
- "MAP_BINARY",
976
-
977
- "MAP_CUSTOM1_F32",
978
- "MAP_CUSTOM2_F32",
979
- "MAP_CUSTOM3_F32",
980
-
981
976
  "MAP_CUSTOM1",
982
977
  "MAP_CUSTOM2",
983
978
  "MAP_CUSTOM3",
984
979
 
980
+ "CUSTOM",
981
+
985
982
  "CROSS_ENTROPY_LOSS",
986
983
  "CROSS_ENTROPY_LOSS_BACK",
987
984
  "OPT_STEP_ADAMW",
@@ -1017,6 +1014,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1017
1014
  "rms_norm(x)",
1018
1015
  "rms_norm_back(x)",
1019
1016
  "group_norm(x)",
1017
+ "l2_norm(x)",
1020
1018
 
1021
1019
  "X*Y",
1022
1020
  "X[i]*Y",
@@ -1043,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1043
1041
  "conv_transpose_1d(x)",
1044
1042
  "im2col(x)",
1045
1043
  "im2col_back(x)",
1044
+ "conv_2d_dw(x)",
1046
1045
  "conv_transpose_2d(x)",
1047
1046
  "pool_1d(x)",
1048
1047
  "pool_2d(x)",
@@ -1064,19 +1063,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1064
1063
  "get_rel_pos(x)",
1065
1064
  "add_rel_pos(x)",
1066
1065
  "rwkv_wkv6(k, v, r, tf, td, s)",
1066
+ "gated_linear_attn(k, v, q, gate, s)",
1067
+ "rwkv_wkv7(r, w, k, v, a, b, s)",
1067
1068
 
1068
1069
  "unary(x)",
1069
1070
 
1070
- "f(x)",
1071
- "f(x,y)",
1072
-
1073
- "custom_f32(x)",
1074
- "custom_f32(x,y)",
1075
- "custom_f32(x,y,z)",
1071
+ "map_custom(x)",
1072
+ "map_custom(x,y)",
1073
+ "map_custom(x,y,z)",
1076
1074
 
1077
1075
  "custom(x)",
1078
- "custom(x,y)",
1079
- "custom(x,y,z)",
1080
1076
 
1081
1077
  "cross_entropy_loss(x,y)",
1082
1078
  "cross_entropy_loss_back(x,y)",
@@ -1103,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1103
1099
  "HARDSWISH",
1104
1100
  "HARDSIGMOID",
1105
1101
  "EXP",
1102
+ "GELU_ERF",
1106
1103
  };
1107
1104
 
1108
- static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
1105
+ static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1109
1106
 
1110
1107
 
1111
1108
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1145,6 +1142,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
1145
1142
  }
1146
1143
 
1147
1144
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
1145
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1146
+ if (tensor->ne[i] <= 0) {
1147
+ return 0;
1148
+ }
1149
+ }
1150
+
1148
1151
  size_t nbytes;
1149
1152
  const size_t blck_size = ggml_blck_size(tensor->type);
1150
1153
  if (blck_size == 1) {
@@ -1328,12 +1331,23 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
1328
1331
  return ggml_is_contiguous_n(tensor, 2);
1329
1332
  }
1330
1333
 
1334
+ bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
1335
+ return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
1336
+ }
1337
+
1331
1338
  bool ggml_is_permuted(const struct ggml_tensor * tensor) {
1332
1339
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1333
1340
 
1334
1341
  return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
1335
1342
  }
1336
1343
 
1344
+ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
1345
+ return
1346
+ tensor->nb[0] > tensor->nb[2] &&
1347
+ tensor->nb[1] > tensor->nb[0] &&
1348
+ tensor->nb[2] == ggml_type_size(tensor->type);
1349
+ }
1350
+
1337
1351
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
1338
1352
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1339
1353
 
@@ -1373,7 +1387,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
1373
1387
  (t0->nb[3] == t1->nb[3]);
1374
1388
  }
1375
1389
 
1376
- // check if t1 can be represented as a repeatition of t0
1390
+ // check if t1 can be represented as a repetition of t0
1377
1391
  bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
1378
1392
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1379
1393
 
@@ -1588,15 +1602,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
1588
1602
 
1589
1603
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
1590
1604
 
1591
- #ifdef __clang__
1592
- // temporary until ggml_tensor::backend is removed
1593
- #pragma clang diagnostic push
1594
- #pragma clang diagnostic ignored "-Wdeprecated-declarations"
1595
- #endif
1596
-
1597
1605
  *result = (struct ggml_tensor) {
1598
1606
  /*.type =*/ type,
1599
- /*.backend =*/ GGML_BACKEND_TYPE_CPU,
1600
1607
  /*.buffer =*/ NULL,
1601
1608
  /*.ne =*/ { 1, 1, 1, 1 },
1602
1609
  /*.nb =*/ { 0, 0, 0, 0 },
@@ -1612,10 +1619,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
1612
1619
  /*.padding =*/ { 0 },
1613
1620
  };
1614
1621
 
1615
- #ifdef __clang__
1616
- #pragma clang diagnostic pop
1617
- #endif
1618
-
1619
1622
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
1620
1623
  //GGML_ASSERT_ALIGNED(result->data);
1621
1624
 
@@ -2333,6 +2336,7 @@ struct ggml_tensor * ggml_concat(
2333
2336
  struct ggml_tensor * b,
2334
2337
  int dim) {
2335
2338
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
2339
+ GGML_ASSERT(a->type == b->type);
2336
2340
 
2337
2341
  int64_t ne[GGML_MAX_DIMS];
2338
2342
  for (int d = 0; d < GGML_MAX_DIMS; ++d) {
@@ -2498,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
2498
2502
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
2499
2503
  }
2500
2504
 
2505
+ // ggml_gelu_erf
2506
+
2507
+ struct ggml_tensor * ggml_gelu_erf(
2508
+ struct ggml_context * ctx,
2509
+ struct ggml_tensor * a) {
2510
+ return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
2511
+ }
2512
+
2513
+ struct ggml_tensor * ggml_gelu_erf_inplace(
2514
+ struct ggml_context * ctx,
2515
+ struct ggml_tensor * a) {
2516
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
2517
+ }
2518
+
2501
2519
  // ggml_gelu_quick
2502
2520
 
2503
2521
  struct ggml_tensor * ggml_gelu_quick(
@@ -2686,6 +2704,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
2686
2704
  return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
2687
2705
  }
2688
2706
 
2707
+ // ggml_l2_norm
2708
+
2709
+ static struct ggml_tensor * ggml_l2_norm_impl(
2710
+ struct ggml_context * ctx,
2711
+ struct ggml_tensor * a,
2712
+ float eps,
2713
+ bool inplace) {
2714
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2715
+
2716
+ ggml_set_op_params_f32(result, 0, eps);
2717
+
2718
+ result->op = GGML_OP_L2_NORM;
2719
+ result->src[0] = a;
2720
+
2721
+ return result;
2722
+ }
2723
+
2724
+ struct ggml_tensor * ggml_l2_norm(
2725
+ struct ggml_context * ctx,
2726
+ struct ggml_tensor * a,
2727
+ float eps) {
2728
+ return ggml_l2_norm_impl(ctx, a, eps, false);
2729
+ }
2730
+
2731
+ struct ggml_tensor * ggml_l2_norm_inplace(
2732
+ struct ggml_context * ctx,
2733
+ struct ggml_tensor * a,
2734
+ float eps) {
2735
+ return ggml_l2_norm_impl(ctx, a, eps, true);
2736
+ }
2737
+
2689
2738
  // ggml_mul_mat
2690
2739
 
2691
2740
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -2729,11 +2778,11 @@ void ggml_mul_mat_set_prec(
2729
2778
  c = ggml_mul_mat_id(ctx, as, b, ids);
2730
2779
 
2731
2780
  as -> [cols, rows, n_expert]
2732
- ids -> [n_experts_used, n_tokens] (i32)
2733
2781
  b -> [cols, n_expert_used, n_tokens]
2782
+ ids -> [n_expert_used, n_tokens] (i32)
2734
2783
  c -> [rows, n_expert_used, n_tokens]
2735
2784
 
2736
- in b, n_experts_used can be broadcasted to match the n_expert_used of ids
2785
+ in b, n_expert_used can be broadcasted to match the n_expert_used of ids
2737
2786
 
2738
2787
  c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
2739
2788
  */
@@ -3459,12 +3508,14 @@ struct ggml_tensor * ggml_soft_max_ext(
3459
3508
  return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3460
3509
  }
3461
3510
 
3462
- // ggml_soft_max_back
3511
+ // ggml_soft_max_ext_back
3463
3512
 
3464
- static struct ggml_tensor * ggml_soft_max_back_impl(
3513
+ static struct ggml_tensor * ggml_soft_max_ext_back_impl(
3465
3514
  struct ggml_context * ctx,
3466
3515
  struct ggml_tensor * a,
3467
3516
  struct ggml_tensor * b,
3517
+ float scale,
3518
+ float max_bias,
3468
3519
  bool inplace) {
3469
3520
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3470
3521
 
@@ -3472,21 +3523,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
3472
3523
  result->src[0] = a;
3473
3524
  result->src[1] = b;
3474
3525
 
3526
+ memcpy((float *) result->op_params + 0, &scale, sizeof(float));
3527
+ memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
3528
+
3475
3529
  return result;
3476
3530
  }
3477
3531
 
3478
- struct ggml_tensor * ggml_soft_max_back(
3532
+ struct ggml_tensor * ggml_soft_max_ext_back(
3479
3533
  struct ggml_context * ctx,
3480
3534
  struct ggml_tensor * a,
3481
- struct ggml_tensor * b) {
3482
- return ggml_soft_max_back_impl(ctx, a, b, false);
3535
+ struct ggml_tensor * b,
3536
+ float scale,
3537
+ float max_bias) {
3538
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
3483
3539
  }
3484
3540
 
3485
- struct ggml_tensor * ggml_soft_max_back_inplace(
3541
+ struct ggml_tensor * ggml_soft_max_ext_back_inplace(
3486
3542
  struct ggml_context * ctx,
3487
3543
  struct ggml_tensor * a,
3488
- struct ggml_tensor * b) {
3489
- return ggml_soft_max_back_impl(ctx, a, b, true);
3544
+ struct ggml_tensor * b,
3545
+ float scale,
3546
+ float max_bias) {
3547
+ return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
3490
3548
  }
3491
3549
 
3492
3550
  // ggml_rope
@@ -3704,7 +3762,7 @@ void ggml_rope_yarn_corr_dims(
3704
3762
 
3705
3763
  // ggml_rope_back
3706
3764
 
3707
- struct ggml_tensor * ggml_rope_back(
3765
+ struct ggml_tensor * ggml_rope_ext_back(
3708
3766
  struct ggml_context * ctx,
3709
3767
  struct ggml_tensor * a,
3710
3768
  struct ggml_tensor * b,
@@ -3718,29 +3776,32 @@ struct ggml_tensor * ggml_rope_back(
3718
3776
  float attn_factor,
3719
3777
  float beta_fast,
3720
3778
  float beta_slow) {
3721
- GGML_ASSERT(ggml_is_vector(b));
3722
- GGML_ASSERT(b->type == GGML_TYPE_I32);
3723
- GGML_ASSERT(a->ne[2] == b->ne[0]);
3724
-
3725
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3726
-
3727
- int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3728
- memcpy(params + 5, &freq_base, sizeof(float));
3729
- memcpy(params + 6, &freq_scale, sizeof(float));
3730
- memcpy(params + 7, &ext_factor, sizeof(float));
3731
- memcpy(params + 8, &attn_factor, sizeof(float));
3732
- memcpy(params + 9, &beta_fast, sizeof(float));
3733
- memcpy(params + 10, &beta_slow, sizeof(float));
3734
- ggml_set_op_params(result, params, sizeof(params));
3735
-
3736
- result->op = GGML_OP_ROPE_BACK;
3737
- result->src[0] = a;
3738
- result->src[1] = b;
3739
- result->src[2] = c;
3740
-
3779
+ struct ggml_tensor * result = ggml_rope_ext(
3780
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3781
+ result->op = GGML_OP_ROPE_BACK;
3741
3782
  return result;
3742
3783
  }
3743
3784
 
3785
+ struct ggml_tensor * ggml_rope_multi_back(
3786
+ struct ggml_context * ctx,
3787
+ struct ggml_tensor * a,
3788
+ struct ggml_tensor * b,
3789
+ struct ggml_tensor * c,
3790
+ int n_dims,
3791
+ int sections[4],
3792
+ int mode,
3793
+ int n_ctx_orig,
3794
+ float freq_base,
3795
+ float freq_scale,
3796
+ float ext_factor,
3797
+ float attn_factor,
3798
+ float beta_fast,
3799
+ float beta_slow) {
3800
+ struct ggml_tensor * result = ggml_rope_multi(
3801
+ ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
3802
+ result->op = GGML_OP_ROPE_BACK;
3803
+ return result;
3804
+ }
3744
3805
  // ggml_clamp
3745
3806
 
3746
3807
  struct ggml_tensor * ggml_clamp(
@@ -3760,104 +3821,10 @@ struct ggml_tensor * ggml_clamp(
3760
3821
  return result;
3761
3822
  }
3762
3823
 
3763
- // ggml_conv_1d
3764
-
3765
3824
  static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3766
3825
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
3767
3826
  }
3768
3827
 
3769
- GGML_API struct ggml_tensor * ggml_conv_1d(
3770
- struct ggml_context * ctx,
3771
- struct ggml_tensor * a,
3772
- struct ggml_tensor * b,
3773
- int s0,
3774
- int p0,
3775
- int d0) {
3776
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
3777
-
3778
- struct ggml_tensor * result =
3779
- ggml_mul_mat(ctx,
3780
- ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
3781
- ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
3782
-
3783
- result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
3784
-
3785
- return result;
3786
- }
3787
-
3788
- // ggml_conv_1d_ph
3789
-
3790
- struct ggml_tensor* ggml_conv_1d_ph(
3791
- struct ggml_context * ctx,
3792
- struct ggml_tensor * a,
3793
- struct ggml_tensor * b,
3794
- int s,
3795
- int d) {
3796
- return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
3797
- }
3798
-
3799
- // ggml_conv_transpose_1d
3800
-
3801
- static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3802
- return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
3803
- }
3804
-
3805
- GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
3806
- struct ggml_context * ctx,
3807
- struct ggml_tensor * a,
3808
- struct ggml_tensor * b,
3809
- int s0,
3810
- int p0,
3811
- int d0) {
3812
- GGML_ASSERT(ggml_is_matrix(b));
3813
- GGML_ASSERT(a->ne[2] == b->ne[1]);
3814
- GGML_ASSERT(a->ne[3] == 1);
3815
-
3816
- GGML_ASSERT(p0 == 0);
3817
- GGML_ASSERT(d0 == 1);
3818
-
3819
- const int64_t ne[4] = {
3820
- ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
3821
- a->ne[1], b->ne[2], 1,
3822
- };
3823
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
3824
-
3825
- int32_t params[] = { s0, p0, d0 };
3826
- ggml_set_op_params(result, params, sizeof(params));
3827
-
3828
- result->op = GGML_OP_CONV_TRANSPOSE_1D;
3829
- result->src[0] = a;
3830
- result->src[1] = b;
3831
-
3832
- return result;
3833
- }
3834
-
3835
- // ggml_conv_depthwise
3836
-
3837
- struct ggml_tensor * ggml_conv_depthwise_2d(
3838
- struct ggml_context * ctx,
3839
- struct ggml_tensor * a,
3840
- struct ggml_tensor * b,
3841
- int s0,
3842
- int s1,
3843
- int p0,
3844
- int p1,
3845
- int d0,
3846
- int d1) {
3847
- struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
3848
- struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
3849
- ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
3850
- s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
3851
- struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
3852
-
3853
- new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
3854
- struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
3855
- result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
3856
-
3857
- return result;
3858
- }
3859
- // ggml_conv_2d
3860
-
3861
3828
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
3862
3829
  // a: [OC,IC, KH, KW]
3863
3830
  // b: [N, IC, IH, IW]
@@ -3874,10 +3841,11 @@ struct ggml_tensor * ggml_im2col(
3874
3841
  int d1,
3875
3842
  bool is_2D,
3876
3843
  enum ggml_type dst_type) {
3877
- if(is_2D) {
3844
+ if (is_2D) {
3878
3845
  GGML_ASSERT(a->ne[2] == b->ne[2]);
3879
3846
  } else {
3880
- GGML_ASSERT(a->ne[1] == b->ne[1]);
3847
+ //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
3848
+ GGML_ASSERT(b->ne[1] == a->ne[1]);
3881
3849
  GGML_ASSERT(b->ne[3] == 1);
3882
3850
  }
3883
3851
 
@@ -3928,58 +3896,225 @@ struct ggml_tensor * ggml_im2col_back(
3928
3896
  return result;
3929
3897
  }
3930
3898
 
3931
- // a: [OC,IC, KH, KW]
3932
- // b: [N, IC, IH, IW]
3933
- // result: [N, OC, OH, OW]
3934
- struct ggml_tensor * ggml_conv_2d(
3899
+ // ggml_conv_1d
3900
+
3901
+ struct ggml_tensor * ggml_conv_1d(
3935
3902
  struct ggml_context * ctx,
3936
3903
  struct ggml_tensor * a,
3937
3904
  struct ggml_tensor * b,
3938
3905
  int s0,
3939
- int s1,
3940
3906
  int p0,
3941
- int p1,
3942
- int d0,
3943
- int d1) {
3944
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
3907
+ int d0) {
3908
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
3945
3909
 
3946
3910
  struct ggml_tensor * result =
3947
3911
  ggml_mul_mat(ctx,
3948
- ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
3949
- ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
3950
-
3951
- result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
3952
- result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
3912
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
3913
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
3953
3914
 
3915
+ result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
3954
3916
 
3955
3917
  return result;
3956
3918
  }
3957
3919
 
3958
- // ggml_conv_2d_sk_p0
3920
+ // ggml_conv_1d_ph
3959
3921
 
3960
- struct ggml_tensor * ggml_conv_2d_sk_p0(
3922
+ struct ggml_tensor* ggml_conv_1d_ph(
3961
3923
  struct ggml_context * ctx,
3962
3924
  struct ggml_tensor * a,
3963
- struct ggml_tensor * b) {
3964
- return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
3925
+ struct ggml_tensor * b,
3926
+ int s,
3927
+ int d) {
3928
+ return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
3965
3929
  }
3966
3930
 
3967
- // ggml_conv_2d_s1_ph
3931
+ // ggml_conv_1d_dw
3968
3932
 
3969
- struct ggml_tensor * ggml_conv_2d_s1_ph(
3933
+ struct ggml_tensor * ggml_conv_1d_dw(
3970
3934
  struct ggml_context * ctx,
3971
3935
  struct ggml_tensor * a,
3972
- struct ggml_tensor * b) {
3973
- return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
3974
- }
3936
+ struct ggml_tensor * b,
3937
+ int s0,
3938
+ int p0,
3939
+ int d0) {
3940
+ struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
3941
+ struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
3975
3942
 
3976
- // ggml_conv_transpose_2d_p0
3943
+ struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
3977
3944
 
3978
- static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
3979
- return (ins - 1) * s - 2 * p + ks;
3945
+ struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
3946
+
3947
+ result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
3948
+
3949
+ return result;
3980
3950
  }
3981
3951
 
3982
- struct ggml_tensor * ggml_conv_transpose_2d_p0(
3952
+ // ggml_conv_1d_dw_ph
3953
+
3954
+ struct ggml_tensor * ggml_conv_1d_dw_ph(
3955
+ struct ggml_context * ctx,
3956
+ struct ggml_tensor * a,
3957
+ struct ggml_tensor * b,
3958
+ int s0,
3959
+ int d0) {
3960
+ return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
3961
+ }
3962
+
3963
+ // ggml_conv_transpose_1d
3964
+
3965
+ static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
3966
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
3967
+ }
3968
+
3969
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
3970
+ struct ggml_context * ctx,
3971
+ struct ggml_tensor * a,
3972
+ struct ggml_tensor * b,
3973
+ int s0,
3974
+ int p0,
3975
+ int d0) {
3976
+ GGML_ASSERT(ggml_is_matrix(b));
3977
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
3978
+ GGML_ASSERT(a->ne[3] == 1);
3979
+
3980
+ GGML_ASSERT(p0 == 0);
3981
+ GGML_ASSERT(d0 == 1);
3982
+
3983
+ const int64_t ne[4] = {
3984
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
3985
+ a->ne[1], b->ne[2], 1,
3986
+ };
3987
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
3988
+
3989
+ int32_t params[] = { s0, p0, d0 };
3990
+ ggml_set_op_params(result, params, sizeof(params));
3991
+
3992
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
3993
+ result->src[0] = a;
3994
+ result->src[1] = b;
3995
+
3996
+ return result;
3997
+ }
3998
+
3999
+ // ggml_conv_2d
4000
+
4001
+ // a: [OC,IC, KH, KW]
4002
+ // b: [N, IC, IH, IW]
4003
+ // result: [N, OC, OH, OW]
4004
+ struct ggml_tensor * ggml_conv_2d(
4005
+ struct ggml_context * ctx,
4006
+ struct ggml_tensor * a,
4007
+ struct ggml_tensor * b,
4008
+ int s0,
4009
+ int s1,
4010
+ int p0,
4011
+ int p1,
4012
+ int d0,
4013
+ int d1) {
4014
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
4015
+
4016
+ struct ggml_tensor * result =
4017
+ ggml_mul_mat(ctx,
4018
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
4019
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
4020
+
4021
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
4022
+ result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
4023
+
4024
+
4025
+ return result;
4026
+ }
4027
+
4028
+ // ggml_conv_2d_sk_p0
4029
+
4030
+ struct ggml_tensor * ggml_conv_2d_sk_p0(
4031
+ struct ggml_context * ctx,
4032
+ struct ggml_tensor * a,
4033
+ struct ggml_tensor * b) {
4034
+ return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
4035
+ }
4036
+
4037
+ // ggml_conv_2d_s1_ph
4038
+
4039
+ struct ggml_tensor * ggml_conv_2d_s1_ph(
4040
+ struct ggml_context * ctx,
4041
+ struct ggml_tensor * a,
4042
+ struct ggml_tensor * b) {
4043
+ return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
4044
+ }
4045
+
4046
+ // ggml_conv_2d_dw
4047
+
4048
+ struct ggml_tensor * ggml_conv_2d_dw(
4049
+ struct ggml_context * ctx,
4050
+ struct ggml_tensor * a,
4051
+ struct ggml_tensor * b,
4052
+ int s0,
4053
+ int s1,
4054
+ int p0,
4055
+ int p1,
4056
+ int d0,
4057
+ int d1) {
4058
+ struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
4059
+ struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
4060
+ ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
4061
+ s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
4062
+ struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
4063
+
4064
+ new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
4065
+ struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
4066
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
4067
+
4068
+ return result;
4069
+ }
4070
+
4071
+ // ggml_conv_2d_dw_direct
4072
+
4073
+ struct ggml_tensor * ggml_conv_2d_dw_direct(
4074
+ struct ggml_context * ctx,
4075
+ struct ggml_tensor * a,
4076
+ struct ggml_tensor * b,
4077
+ int stride0,
4078
+ int stride1,
4079
+ int pad0,
4080
+ int pad1,
4081
+ int dilation0,
4082
+ int dilation1) {
4083
+ GGML_ASSERT(a->ne[2] == 1);
4084
+ GGML_ASSERT(a->ne[3] == b->ne[2]);
4085
+ int64_t ne[4];
4086
+ ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
4087
+ ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
4088
+ ne[2] = b->ne[2];
4089
+ ne[3] = b->ne[3];
4090
+
4091
+ struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4092
+
4093
+ if (ggml_is_contiguous_channels(b)) {
4094
+ // Result will be permuted the same way as input (CWHN order)
4095
+ const int64_t type_size = ggml_type_size(result->type);
4096
+ GGML_ASSERT(ggml_blck_size(result->type) == 1);
4097
+ result->nb[0] = result->ne[2] * type_size;
4098
+ result->nb[1] = result->ne[0] * result->nb[0];
4099
+ result->nb[2] = type_size;
4100
+ }
4101
+
4102
+ int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
4103
+ ggml_set_op_params(result, params, sizeof(params));
4104
+
4105
+ result->op = GGML_OP_CONV_2D_DW;
4106
+ result->src[0] = a;
4107
+ result->src[1] = b;
4108
+ return result;
4109
+ }
4110
+
4111
+ // ggml_conv_transpose_2d_p0
4112
+
4113
+ static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
4114
+ return (ins - 1) * s - 2 * p + ks;
4115
+ }
4116
+
4117
+ struct ggml_tensor * ggml_conv_transpose_2d_p0(
3983
4118
  struct ggml_context * ctx,
3984
4119
  struct ggml_tensor * a,
3985
4120
  struct ggml_tensor * b,
@@ -4097,7 +4232,8 @@ static struct ggml_tensor * ggml_upscale_impl(
4097
4232
  int ne0,
4098
4233
  int ne1,
4099
4234
  int ne2,
4100
- int ne3) {
4235
+ int ne3,
4236
+ enum ggml_scale_mode mode) {
4101
4237
  GGML_ASSERT(a->ne[0] <= ne0);
4102
4238
  GGML_ASSERT(a->ne[1] <= ne1);
4103
4239
  GGML_ASSERT(a->ne[2] <= ne2);
@@ -4105,6 +4241,8 @@ static struct ggml_tensor * ggml_upscale_impl(
4105
4241
 
4106
4242
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
4107
4243
 
4244
+ ggml_set_op_params_i32(result, 0, mode);
4245
+
4108
4246
  result->op = GGML_OP_UPSCALE;
4109
4247
  result->src[0] = a;
4110
4248
 
@@ -4114,8 +4252,9 @@ static struct ggml_tensor * ggml_upscale_impl(
4114
4252
  struct ggml_tensor * ggml_upscale(
4115
4253
  struct ggml_context * ctx,
4116
4254
  struct ggml_tensor * a,
4117
- int scale_factor) {
4118
- return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
4255
+ int scale_factor,
4256
+ enum ggml_scale_mode mode) {
4257
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4119
4258
  }
4120
4259
 
4121
4260
  struct ggml_tensor * ggml_upscale_ext(
@@ -4124,8 +4263,9 @@ struct ggml_tensor * ggml_upscale_ext(
4124
4263
  int ne0,
4125
4264
  int ne1,
4126
4265
  int ne2,
4127
- int ne3) {
4128
- return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
4266
+ int ne3,
4267
+ enum ggml_scale_mode mode) {
4268
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4129
4269
  }
4130
4270
 
4131
4271
  // ggml_pad
@@ -4288,7 +4428,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
4288
4428
  }
4289
4429
 
4290
4430
  // permute(0, 2, 1, 3)
4291
- int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
4431
+ int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
4292
4432
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4293
4433
 
4294
4434
  float params[] = { scale, max_bias, logit_softcap };
@@ -4606,15 +4746,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
4606
4746
  GGML_ASSERT(ggml_is_contiguous(state));
4607
4747
 
4608
4748
  const int64_t S = k->ne[0];
4609
- const int64_t H = k->ne[2];
4610
- const int64_t n_tokens = k->ne[3];
4749
+ const int64_t H = k->ne[1];
4750
+ const int64_t n_tokens = k->ne[2];
4611
4751
  const int64_t n_seqs = state->ne[1];
4612
4752
  {
4613
- GGML_ASSERT(k->ne[1] == 1);
4614
- GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
4615
- GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
4616
- // TODO: RWKV v4 and v5
4617
- GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
4753
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4754
+ GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
4755
+ GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
4618
4756
  GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
4619
4757
  }
4620
4758
 
@@ -4633,6 +4771,97 @@ struct ggml_tensor * ggml_rwkv_wkv6(
4633
4771
  return result;
4634
4772
  }
4635
4773
 
4774
+ // ggml_gated_linear_attn
4775
+
4776
+ struct ggml_tensor * ggml_gated_linear_attn(
4777
+ struct ggml_context * ctx,
4778
+ struct ggml_tensor * k,
4779
+ struct ggml_tensor * v,
4780
+ struct ggml_tensor * q,
4781
+ struct ggml_tensor * g,
4782
+ struct ggml_tensor * state,
4783
+ float scale) {
4784
+ GGML_ASSERT(ggml_is_contiguous(k));
4785
+ GGML_ASSERT(ggml_is_contiguous(v));
4786
+ GGML_ASSERT(ggml_is_contiguous(q));
4787
+ GGML_ASSERT(ggml_is_contiguous(g));
4788
+ GGML_ASSERT(ggml_is_contiguous(state));
4789
+
4790
+ const int64_t S = k->ne[0];
4791
+ const int64_t H = k->ne[1];
4792
+ const int64_t n_tokens = k->ne[2];
4793
+ const int64_t n_seqs = state->ne[1];
4794
+ {
4795
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4796
+ GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
4797
+ GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
4798
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
4799
+ }
4800
+
4801
+ // concat output and new_state
4802
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
4803
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4804
+
4805
+ ggml_set_op_params_f32(result, 0, scale);
4806
+
4807
+ result->op = GGML_OP_GATED_LINEAR_ATTN;
4808
+ result->src[0] = k;
4809
+ result->src[1] = v;
4810
+ result->src[2] = q;
4811
+ result->src[3] = g;
4812
+ result->src[4] = state;
4813
+
4814
+ return result;
4815
+ }
4816
+
4817
+ // ggml_rwkv_wkv7
4818
+
4819
+ struct ggml_tensor * ggml_rwkv_wkv7(
4820
+ struct ggml_context * ctx,
4821
+ struct ggml_tensor * r,
4822
+ struct ggml_tensor * w,
4823
+ struct ggml_tensor * k,
4824
+ struct ggml_tensor * v,
4825
+ struct ggml_tensor * a,
4826
+ struct ggml_tensor * b,
4827
+ struct ggml_tensor * state) {
4828
+ GGML_ASSERT(ggml_is_contiguous(r));
4829
+ GGML_ASSERT(ggml_is_contiguous(w));
4830
+ GGML_ASSERT(ggml_is_contiguous(k));
4831
+ GGML_ASSERT(ggml_is_contiguous(v));
4832
+ GGML_ASSERT(ggml_is_contiguous(a));
4833
+ GGML_ASSERT(ggml_is_contiguous(b));
4834
+ GGML_ASSERT(ggml_is_contiguous(state));
4835
+
4836
+ const int64_t S = k->ne[0];
4837
+ const int64_t H = k->ne[1];
4838
+ const int64_t n_tokens = k->ne[2];
4839
+ const int64_t n_seqs = state->ne[1];
4840
+ {
4841
+ GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
4842
+ GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
4843
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
4844
+ GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
4845
+ GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
4846
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
4847
+ }
4848
+
4849
+ // concat output and new_state
4850
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
4851
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4852
+
4853
+ result->op = GGML_OP_RWKV_WKV7;
4854
+ result->src[0] = r;
4855
+ result->src[1] = w;
4856
+ result->src[2] = k;
4857
+ result->src[3] = v;
4858
+ result->src[4] = a;
4859
+ result->src[5] = b;
4860
+ result->src[6] = state;
4861
+
4862
+ return result;
4863
+ }
4864
+
4636
4865
  // ggml_unary
4637
4866
 
4638
4867
  static struct ggml_tensor * ggml_unary_impl(
@@ -4666,185 +4895,106 @@ struct ggml_tensor * ggml_unary_inplace(
4666
4895
  return ggml_unary_impl(ctx, a, op, true);
4667
4896
  }
4668
4897
 
4669
- // ggml_map_unary
4898
+ // ggml_map_custom1
4899
+
4900
+ static struct ggml_tensor * ggml_map_custom1_impl(
4901
+ struct ggml_context * ctx,
4902
+ struct ggml_tensor * a,
4903
+ const ggml_custom1_op_t fun,
4904
+ int n_tasks,
4905
+ void * userdata,
4906
+ bool inplace) {
4907
+ GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
4670
4908
 
4671
- static struct ggml_tensor * ggml_map_unary_impl_f32(
4672
- struct ggml_context * ctx,
4673
- struct ggml_tensor * a,
4674
- const ggml_unary_op_f32_t fun,
4675
- bool inplace) {
4676
4909
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4677
4910
 
4678
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
4911
+ struct ggml_map_custom1_op_params params = {
4912
+ /*.fun =*/ fun,
4913
+ /*.n_tasks =*/ n_tasks,
4914
+ /*.userdata =*/ userdata
4915
+ };
4916
+ ggml_set_op_params(result, &params, sizeof(params));
4679
4917
 
4680
- result->op = GGML_OP_MAP_UNARY;
4918
+ result->op = GGML_OP_MAP_CUSTOM1;
4681
4919
  result->src[0] = a;
4682
4920
 
4683
4921
  return result;
4684
4922
  }
4685
4923
 
4686
- struct ggml_tensor * ggml_map_unary_f32(
4687
- struct ggml_context * ctx,
4688
- struct ggml_tensor * a,
4689
- const ggml_unary_op_f32_t fun) {
4690
- return ggml_map_unary_impl_f32(ctx, a, fun, false);
4924
+ struct ggml_tensor * ggml_map_custom1(
4925
+ struct ggml_context * ctx,
4926
+ struct ggml_tensor * a,
4927
+ const ggml_custom1_op_t fun,
4928
+ int n_tasks,
4929
+ void * userdata) {
4930
+ return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
4691
4931
  }
4692
4932
 
4693
- struct ggml_tensor * ggml_map_unary_inplace_f32(
4694
- struct ggml_context * ctx,
4695
- struct ggml_tensor * a,
4696
- const ggml_unary_op_f32_t fun) {
4697
- return ggml_map_unary_impl_f32(ctx, a, fun, true);
4933
+ struct ggml_tensor * ggml_map_custom1_inplace(
4934
+ struct ggml_context * ctx,
4935
+ struct ggml_tensor * a,
4936
+ const ggml_custom1_op_t fun,
4937
+ int n_tasks,
4938
+ void * userdata) {
4939
+ return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
4698
4940
  }
4699
4941
 
4700
- // ggml_map_binary
4942
+ // ggml_map_custom2
4701
4943
 
4702
- static struct ggml_tensor * ggml_map_binary_impl_f32(
4703
- struct ggml_context * ctx,
4704
- struct ggml_tensor * a,
4705
- struct ggml_tensor * b,
4706
- const ggml_binary_op_f32_t fun,
4707
- bool inplace) {
4708
- GGML_ASSERT(ggml_are_same_shape(a, b));
4944
+ static struct ggml_tensor * ggml_map_custom2_impl(
4945
+ struct ggml_context * ctx,
4946
+ struct ggml_tensor * a,
4947
+ struct ggml_tensor * b,
4948
+ const ggml_custom2_op_t fun,
4949
+ int n_tasks,
4950
+ void * userdata,
4951
+ bool inplace) {
4952
+ GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
4709
4953
 
4710
4954
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4711
4955
 
4712
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
4956
+ struct ggml_map_custom2_op_params params = {
4957
+ /*.fun =*/ fun,
4958
+ /*.n_tasks =*/ n_tasks,
4959
+ /*.userdata =*/ userdata
4960
+ };
4961
+ ggml_set_op_params(result, &params, sizeof(params));
4713
4962
 
4714
- result->op = GGML_OP_MAP_BINARY;
4963
+ result->op = GGML_OP_MAP_CUSTOM2;
4715
4964
  result->src[0] = a;
4716
4965
  result->src[1] = b;
4717
4966
 
4718
4967
  return result;
4719
4968
  }
4720
4969
 
4721
- struct ggml_tensor * ggml_map_binary_f32(
4722
- struct ggml_context * ctx,
4723
- struct ggml_tensor * a,
4724
- struct ggml_tensor * b,
4725
- const ggml_binary_op_f32_t fun) {
4726
- return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
4970
+ struct ggml_tensor * ggml_map_custom2(
4971
+ struct ggml_context * ctx,
4972
+ struct ggml_tensor * a,
4973
+ struct ggml_tensor * b,
4974
+ const ggml_custom2_op_t fun,
4975
+ int n_tasks,
4976
+ void * userdata) {
4977
+ return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
4727
4978
  }
4728
4979
 
4729
- struct ggml_tensor * ggml_map_binary_inplace_f32(
4730
- struct ggml_context * ctx,
4731
- struct ggml_tensor * a,
4732
- struct ggml_tensor * b,
4733
- const ggml_binary_op_f32_t fun) {
4734
- return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
4980
+ struct ggml_tensor * ggml_map_custom2_inplace(
4981
+ struct ggml_context * ctx,
4982
+ struct ggml_tensor * a,
4983
+ struct ggml_tensor * b,
4984
+ const ggml_custom2_op_t fun,
4985
+ int n_tasks,
4986
+ void * userdata) {
4987
+ return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
4735
4988
  }
4736
4989
 
4737
- // ggml_map_custom1_f32
4738
-
4739
- static struct ggml_tensor * ggml_map_custom1_impl_f32(
4740
- struct ggml_context * ctx,
4741
- struct ggml_tensor * a,
4742
- const ggml_custom1_op_f32_t fun,
4743
- bool inplace) {
4744
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4745
-
4746
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
4747
-
4748
- result->op = GGML_OP_MAP_CUSTOM1_F32;
4749
- result->src[0] = a;
4750
-
4751
- return result;
4752
- }
4753
-
4754
- struct ggml_tensor * ggml_map_custom1_f32(
4755
- struct ggml_context * ctx,
4756
- struct ggml_tensor * a,
4757
- const ggml_custom1_op_f32_t fun) {
4758
- return ggml_map_custom1_impl_f32(ctx, a, fun, false);
4759
- }
4760
-
4761
- struct ggml_tensor * ggml_map_custom1_inplace_f32(
4762
- struct ggml_context * ctx,
4763
- struct ggml_tensor * a,
4764
- const ggml_custom1_op_f32_t fun) {
4765
- return ggml_map_custom1_impl_f32(ctx, a, fun, true);
4766
- }
4767
-
4768
- // ggml_map_custom2_f32
4769
-
4770
- static struct ggml_tensor * ggml_map_custom2_impl_f32(
4771
- struct ggml_context * ctx,
4772
- struct ggml_tensor * a,
4773
- struct ggml_tensor * b,
4774
- const ggml_custom2_op_f32_t fun,
4775
- bool inplace) {
4776
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4777
-
4778
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
4779
-
4780
- result->op = GGML_OP_MAP_CUSTOM2_F32;
4781
- result->src[0] = a;
4782
- result->src[1] = b;
4783
-
4784
- return result;
4785
- }
4786
-
4787
- struct ggml_tensor * ggml_map_custom2_f32(
4788
- struct ggml_context * ctx,
4789
- struct ggml_tensor * a,
4790
- struct ggml_tensor * b,
4791
- const ggml_custom2_op_f32_t fun) {
4792
- return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
4793
- }
4794
-
4795
- struct ggml_tensor * ggml_map_custom2_inplace_f32(
4796
- struct ggml_context * ctx,
4797
- struct ggml_tensor * a,
4798
- struct ggml_tensor * b,
4799
- const ggml_custom2_op_f32_t fun) {
4800
- return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
4801
- }
4802
-
4803
- // ggml_map_custom3_f32
4804
-
4805
- static struct ggml_tensor * ggml_map_custom3_impl_f32(
4806
- struct ggml_context * ctx,
4807
- struct ggml_tensor * a,
4808
- struct ggml_tensor * b,
4809
- struct ggml_tensor * c,
4810
- const ggml_custom3_op_f32_t fun,
4811
- bool inplace) {
4812
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4813
-
4814
- ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
4815
-
4816
- result->op = GGML_OP_MAP_CUSTOM3_F32;
4817
- result->src[0] = a;
4818
- result->src[1] = b;
4819
- result->src[2] = c;
4820
-
4821
- return result;
4822
- }
4823
-
4824
- struct ggml_tensor * ggml_map_custom3_f32(
4825
- struct ggml_context * ctx,
4826
- struct ggml_tensor * a,
4827
- struct ggml_tensor * b,
4828
- struct ggml_tensor * c,
4829
- const ggml_custom3_op_f32_t fun) {
4830
- return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
4831
- }
4832
-
4833
- struct ggml_tensor * ggml_map_custom3_inplace_f32(
4834
- struct ggml_context * ctx,
4835
- struct ggml_tensor * a,
4836
- struct ggml_tensor * b,
4837
- struct ggml_tensor * c,
4838
- const ggml_custom3_op_f32_t fun) {
4839
- return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
4840
- }
4841
-
4842
- // ggml_map_custom1
4990
+ // ggml_map_custom3
4843
4991
 
4844
- static struct ggml_tensor * ggml_map_custom1_impl(
4992
+ static struct ggml_tensor * ggml_map_custom3_impl(
4845
4993
  struct ggml_context * ctx,
4846
4994
  struct ggml_tensor * a,
4847
- const ggml_custom1_op_t fun,
4995
+ struct ggml_tensor * b,
4996
+ struct ggml_tensor * c,
4997
+ const ggml_custom3_op_t fun,
4848
4998
  int n_tasks,
4849
4999
  void * userdata,
4850
5000
  bool inplace) {
@@ -4852,137 +5002,103 @@ static struct ggml_tensor * ggml_map_custom1_impl(
4852
5002
 
4853
5003
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4854
5004
 
4855
- struct ggml_map_custom1_op_params params = {
5005
+ struct ggml_map_custom3_op_params params = {
4856
5006
  /*.fun =*/ fun,
4857
5007
  /*.n_tasks =*/ n_tasks,
4858
5008
  /*.userdata =*/ userdata
4859
5009
  };
4860
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5010
+ ggml_set_op_params(result, &params, sizeof(params));
4861
5011
 
4862
- result->op = GGML_OP_MAP_CUSTOM1;
5012
+ result->op = GGML_OP_MAP_CUSTOM3;
4863
5013
  result->src[0] = a;
5014
+ result->src[1] = b;
5015
+ result->src[2] = c;
4864
5016
 
4865
5017
  return result;
4866
5018
  }
4867
5019
 
4868
- struct ggml_tensor * ggml_map_custom1(
5020
+ struct ggml_tensor * ggml_map_custom3(
4869
5021
  struct ggml_context * ctx,
4870
5022
  struct ggml_tensor * a,
4871
- const ggml_custom1_op_t fun,
5023
+ struct ggml_tensor * b,
5024
+ struct ggml_tensor * c,
5025
+ const ggml_custom3_op_t fun,
4872
5026
  int n_tasks,
4873
5027
  void * userdata) {
4874
- return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
5028
+ return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
4875
5029
  }
4876
5030
 
4877
- struct ggml_tensor * ggml_map_custom1_inplace(
5031
+ struct ggml_tensor * ggml_map_custom3_inplace(
4878
5032
  struct ggml_context * ctx,
4879
5033
  struct ggml_tensor * a,
4880
- const ggml_custom1_op_t fun,
5034
+ struct ggml_tensor * b,
5035
+ struct ggml_tensor * c,
5036
+ const ggml_custom3_op_t fun,
4881
5037
  int n_tasks,
4882
5038
  void * userdata) {
4883
- return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
5039
+ return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
4884
5040
  }
4885
5041
 
4886
- // ggml_map_custom2
5042
+ struct ggml_tensor * ggml_custom_4d(
5043
+ struct ggml_context * ctx,
5044
+ enum ggml_type type,
5045
+ int64_t ne0,
5046
+ int64_t ne1,
5047
+ int64_t ne2,
5048
+ int64_t ne3,
5049
+ struct ggml_tensor ** args,
5050
+ int n_args,
5051
+ ggml_custom_op_t fun,
5052
+ int n_tasks,
5053
+ void * userdata) {
4887
5054
 
4888
- static struct ggml_tensor * ggml_map_custom2_impl(
4889
- struct ggml_context * ctx,
4890
- struct ggml_tensor * a,
4891
- struct ggml_tensor * b,
4892
- const ggml_custom2_op_t fun,
4893
- int n_tasks,
4894
- void * userdata,
4895
- bool inplace) {
4896
- GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
5055
+ GGML_ASSERT(n_args < GGML_MAX_SRC);
4897
5056
 
4898
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5057
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
4899
5058
 
4900
- struct ggml_map_custom2_op_params params = {
5059
+ struct ggml_custom_op_params params = {
4901
5060
  /*.fun =*/ fun,
4902
5061
  /*.n_tasks =*/ n_tasks,
4903
5062
  /*.userdata =*/ userdata
4904
5063
  };
4905
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5064
+ ggml_set_op_params(result, &params, sizeof(params));
4906
5065
 
4907
- result->op = GGML_OP_MAP_CUSTOM2;
4908
- result->src[0] = a;
4909
- result->src[1] = b;
5066
+ result->op = GGML_OP_CUSTOM;
5067
+ for (int i = 0; i < n_args; i++) {
5068
+ result->src[i] = args[i];
5069
+ }
4910
5070
 
4911
5071
  return result;
4912
5072
  }
4913
5073
 
4914
- struct ggml_tensor * ggml_map_custom2(
4915
- struct ggml_context * ctx,
4916
- struct ggml_tensor * a,
4917
- struct ggml_tensor * b,
4918
- const ggml_custom2_op_t fun,
4919
- int n_tasks,
4920
- void * userdata) {
4921
- return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
4922
- }
4923
-
4924
- struct ggml_tensor * ggml_map_custom2_inplace(
4925
- struct ggml_context * ctx,
4926
- struct ggml_tensor * a,
4927
- struct ggml_tensor * b,
4928
- const ggml_custom2_op_t fun,
4929
- int n_tasks,
4930
- void * userdata) {
4931
- return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
4932
- }
4933
-
4934
- // ggml_map_custom3
5074
+ struct ggml_tensor * ggml_custom_inplace(
5075
+ struct ggml_context * ctx,
5076
+ struct ggml_tensor * a,
5077
+ struct ggml_tensor ** args,
5078
+ int n_args,
5079
+ ggml_custom_op_t fun,
5080
+ int n_tasks,
5081
+ void * userdata) {
4935
5082
 
4936
- static struct ggml_tensor * ggml_map_custom3_impl(
4937
- struct ggml_context * ctx,
4938
- struct ggml_tensor * a,
4939
- struct ggml_tensor * b,
4940
- struct ggml_tensor * c,
4941
- const ggml_custom3_op_t fun,
4942
- int n_tasks,
4943
- void * userdata,
4944
- bool inplace) {
4945
- GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
5083
+ GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
4946
5084
 
4947
- struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5085
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
4948
5086
 
4949
- struct ggml_map_custom3_op_params params = {
5087
+ struct ggml_custom_op_params params = {
4950
5088
  /*.fun =*/ fun,
4951
5089
  /*.n_tasks =*/ n_tasks,
4952
5090
  /*.userdata =*/ userdata
4953
5091
  };
4954
- ggml_set_op_params(result, (const void *) &params, sizeof(params));
5092
+ ggml_set_op_params(result, &params, sizeof(params));
4955
5093
 
4956
- result->op = GGML_OP_MAP_CUSTOM3;
5094
+ result->op = GGML_OP_CUSTOM;
4957
5095
  result->src[0] = a;
4958
- result->src[1] = b;
4959
- result->src[2] = c;
5096
+ for (int i = 0; i < n_args; i++) {
5097
+ result->src[i + 1] = args[i];
5098
+ }
4960
5099
 
4961
5100
  return result;
4962
5101
  }
4963
-
4964
- struct ggml_tensor * ggml_map_custom3(
4965
- struct ggml_context * ctx,
4966
- struct ggml_tensor * a,
4967
- struct ggml_tensor * b,
4968
- struct ggml_tensor * c,
4969
- const ggml_custom3_op_t fun,
4970
- int n_tasks,
4971
- void * userdata) {
4972
- return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
4973
- }
4974
-
4975
- struct ggml_tensor * ggml_map_custom3_inplace(
4976
- struct ggml_context * ctx,
4977
- struct ggml_tensor * a,
4978
- struct ggml_tensor * b,
4979
- struct ggml_tensor * c,
4980
- const ggml_custom3_op_t fun,
4981
- int n_tasks,
4982
- void * userdata) {
4983
- return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
4984
- }
4985
-
4986
5102
  // ggml_cross_entropy_loss
4987
5103
 
4988
5104
  struct ggml_tensor * ggml_cross_entropy_loss(
@@ -5007,10 +5123,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
5007
5123
  struct ggml_tensor * a,
5008
5124
  struct ggml_tensor * b,
5009
5125
  struct ggml_tensor * c) {
5010
- GGML_ASSERT(ggml_are_same_shape(a, b));
5011
- GGML_ASSERT(ggml_is_scalar(c));
5126
+ GGML_ASSERT(ggml_is_scalar(a));
5127
+ GGML_ASSERT(ggml_are_same_shape(b, c));
5012
5128
 
5013
- struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
5129
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
5014
5130
 
5015
5131
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
5016
5132
  result->src[0] = a;
@@ -5189,7 +5305,7 @@ static void ggml_sub_or_set(
5189
5305
  }
5190
5306
 
5191
5307
  static void ggml_compute_backward(
5192
- struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
5308
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
5193
5309
  struct ggml_tensor * tensor = cgraph->nodes[i];
5194
5310
  struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
5195
5311
 
@@ -5261,7 +5377,7 @@ static void ggml_compute_backward(
5261
5377
  } break;
5262
5378
  case GGML_OP_MUL: {
5263
5379
  if (src0_needs_grads) {
5264
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
5380
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
5265
5381
  }
5266
5382
  if (src1_needs_grads) {
5267
5383
  struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5333,7 +5449,7 @@ static void ggml_compute_backward(
5333
5449
  if (src0_needs_grads) {
5334
5450
  float eps;
5335
5451
  memcpy(&eps, tensor->op_params, sizeof(float));
5336
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
5452
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
5337
5453
  }
5338
5454
  } break;
5339
5455
  case GGML_OP_MUL_MAT: {
@@ -5353,21 +5469,25 @@ static void ggml_compute_backward(
5353
5469
  // src1.shape [n,p,qq,rr]
5354
5470
 
5355
5471
  if (src0_needs_grads) {
5356
- struct ggml_tensor * s1_tg =
5472
+ GGML_ASSERT(grad->ne[2] == src1->ne[2]);
5473
+ GGML_ASSERT(grad->ne[3] == src1->ne[3]);
5474
+ struct ggml_tensor * tmp =
5357
5475
  ggml_out_prod(ctx, // [n,m,qq,rr]
5358
5476
  src1, // [n,p,qq,rr]
5359
5477
  grad); // [m,p,qq,rr]
5360
- const int64_t qq = s1_tg->ne[2];
5361
- const int64_t rr = s1_tg->ne[3];
5362
- const int64_t q1 = src0->ne[2];
5363
- const int64_t r1 = src0->ne[3];
5364
- const bool ne2_broadcasted = qq > q1;
5365
- const bool ne3_broadcasted = rr > r1;
5366
- if (ne2_broadcasted || ne3_broadcasted) {
5367
- // sum broadcast repetitions of s1_tg into shape of src0
5368
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
5478
+ if (!ggml_are_same_shape(tmp, src0)) {
5479
+ GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
5480
+ GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
5481
+ GGML_ASSERT(tmp->ne[3] == 1);
5482
+
5483
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
5484
+ const size_t nb2 = tmp->nb[2] * nr2;
5485
+ const size_t nb3 = tmp->nb[2];
5486
+
5487
+ tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
5488
+ tmp = ggml_repeat_back(ctx, tmp, src0);
5369
5489
  }
5370
- ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
5490
+ ggml_add_or_set(ctx, cgraph, isrc0, tmp);
5371
5491
  }
5372
5492
  if (src1_needs_grads) {
5373
5493
  ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5425,7 +5545,7 @@ static void ggml_compute_backward(
5425
5545
  // tensor = src0 * 1 + src1 * 0
5426
5546
  if (src0_needs_grads) {
5427
5547
  // dsrc0 = dtensor * 1
5428
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
5548
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
5429
5549
  }
5430
5550
  if (src1_needs_grads) {
5431
5551
  // dsrc1 = dtensor * 0 -> noop
@@ -5436,7 +5556,9 @@ static void ggml_compute_backward(
5436
5556
  if (src0_needs_grads) {
5437
5557
  GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
5438
5558
  GGML_ASSERT(ggml_is_contiguous(grad));
5439
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
5559
+ GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
5560
+ ggml_add_or_set(ctx, cgraph, isrc0,
5561
+ ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
5440
5562
  }
5441
5563
  } break;
5442
5564
  case GGML_OP_RESHAPE: {
@@ -5516,7 +5638,13 @@ static void ggml_compute_backward(
5516
5638
  } break;
5517
5639
  case GGML_OP_SOFT_MAX: {
5518
5640
  if (src0_needs_grads) {
5519
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
5641
+ float scale = 1.0f;
5642
+ float max_bias = 0.0f;
5643
+
5644
+ memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
5645
+ memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
5646
+
5647
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
5520
5648
  }
5521
5649
  GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5522
5650
  } break;
@@ -5528,6 +5656,7 @@ static void ggml_compute_backward(
5528
5656
  //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5529
5657
  const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5530
5658
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5659
+ int sections[4] = {0, 0, 0, 0};
5531
5660
 
5532
5661
  memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5533
5662
  memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@@ -5535,10 +5664,14 @@ static void ggml_compute_backward(
5535
5664
  memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5536
5665
  memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5537
5666
  memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5538
-
5539
- ggml_add_or_set(ctx, cgraph, isrc0,
5540
- ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
5541
- freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
5667
+ memcpy(&sections, tensor->op_params + 11, sizeof(sections));
5668
+
5669
+ struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
5670
+ ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
5671
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
5672
+ ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
5673
+ mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
5674
+ ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
5542
5675
  }
5543
5676
  GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5544
5677
  } break;
@@ -5552,7 +5685,7 @@ static void ggml_compute_backward(
5552
5685
  const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
5553
5686
  const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
5554
5687
 
5555
- ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5688
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5556
5689
  }
5557
5690
  } break;
5558
5691
  case GGML_OP_POOL_2D: {
@@ -5595,7 +5728,7 @@ static void ggml_compute_backward(
5595
5728
  } break;
5596
5729
  case GGML_UNARY_OP_SILU: {
5597
5730
  if (src0_needs_grads) {
5598
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
5731
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
5599
5732
  }
5600
5733
  } break;
5601
5734
  case GGML_UNARY_OP_EXP: {
@@ -5612,7 +5745,7 @@ static void ggml_compute_backward(
5612
5745
  } break;
5613
5746
  case GGML_OP_CROSS_ENTROPY_LOSS: {
5614
5747
  if (src0_needs_grads) {
5615
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
5748
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
5616
5749
  }
5617
5750
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5618
5751
  } break;
@@ -5693,10 +5826,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
5693
5826
  }
5694
5827
 
5695
5828
  void ggml_build_backward_expand(
5696
- struct ggml_context * ctx_static,
5697
- struct ggml_context * ctx_compute,
5698
- struct ggml_cgraph * cgraph,
5699
- bool accumulate) {
5829
+ struct ggml_context * ctx,
5830
+ struct ggml_cgraph * cgraph,
5831
+ struct ggml_tensor ** grad_accs) {
5700
5832
  GGML_ASSERT(cgraph->n_nodes > 0);
5701
5833
  GGML_ASSERT(cgraph->grads);
5702
5834
  GGML_ASSERT(cgraph->grad_accs);
@@ -5769,21 +5901,24 @@ void ggml_build_backward_expand(
5769
5901
  GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
5770
5902
  node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
5771
5903
 
5772
- const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
5773
- GGML_ASSERT(igrad != GGML_HASHSET_FULL);
5774
- GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
5775
- if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
5776
- cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
5777
- cgraph->grads[igrad] = cgraph->grad_accs[igrad];
5778
- ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
5904
+ const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
5905
+ GGML_ASSERT(ihash != GGML_HASHSET_FULL);
5906
+ GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
5907
+ if (grad_accs && grad_accs[i]) {
5908
+ cgraph->grad_accs[ihash] = grad_accs[i];
5909
+ cgraph->grads[ihash] = cgraph->grad_accs[ihash];
5910
+ } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
5911
+ // loss tensors always need a gradient accumulator
5912
+ cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
5913
+ cgraph->grads[ihash] = cgraph->grad_accs[ihash];
5779
5914
  }
5780
- grads_needed[igrad] = true;
5915
+ grads_needed[ihash] = true;
5781
5916
  }
5782
5917
 
5783
5918
  for (int i = n_nodes_f - 1; i >= 0; --i) {
5784
5919
  // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
5785
5920
  // use allocator to automatically make inplace operations
5786
- ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
5921
+ ggml_compute_backward(ctx, cgraph, i, grads_needed);
5787
5922
  }
5788
5923
 
5789
5924
  free(grads_needed);
@@ -5929,8 +6064,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
5929
6064
  }
5930
6065
  }
5931
6066
 
5932
- struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
5933
- struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
6067
+ struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
6068
+ struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
5934
6069
  ggml_graph_cpy(cgraph, result);
5935
6070
  return result;
5936
6071
  }
@@ -5949,6 +6084,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
5949
6084
  }
5950
6085
 
5951
6086
  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
6087
+ if (!cgraph) {
6088
+ return;
6089
+ }
5952
6090
  GGML_ASSERT(cgraph->grads != NULL);
5953
6091
 
5954
6092
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6258,8 +6396,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
6258
6396
  tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
6259
6397
  }
6260
6398
 
6261
- void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
6262
- GGML_UNUSED(ctx); // TODO: remove this parameter
6399
+ void ggml_set_param(struct ggml_tensor * tensor) {
6400
+ GGML_ASSERT(tensor->op == GGML_OP_NONE);
6263
6401
  tensor->flags |= GGML_TENSOR_FLAG_PARAM;
6264
6402
  }
6265
6403
 
@@ -6383,1288 +6521,6 @@ size_t ggml_quantize_chunk(
6383
6521
 
6384
6522
  ////////////////////////////////////////////////////////////////////////////////
6385
6523
 
6386
- struct gguf_str {
6387
- uint64_t n; // GGUFv2
6388
- char * data;
6389
- };
6390
-
6391
- static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
6392
- [GGUF_TYPE_UINT8] = sizeof(uint8_t),
6393
- [GGUF_TYPE_INT8] = sizeof(int8_t),
6394
- [GGUF_TYPE_UINT16] = sizeof(uint16_t),
6395
- [GGUF_TYPE_INT16] = sizeof(int16_t),
6396
- [GGUF_TYPE_UINT32] = sizeof(uint32_t),
6397
- [GGUF_TYPE_INT32] = sizeof(int32_t),
6398
- [GGUF_TYPE_FLOAT32] = sizeof(float),
6399
- [GGUF_TYPE_BOOL] = sizeof(bool),
6400
- [GGUF_TYPE_STRING] = sizeof(struct gguf_str),
6401
- [GGUF_TYPE_UINT64] = sizeof(uint64_t),
6402
- [GGUF_TYPE_INT64] = sizeof(int64_t),
6403
- [GGUF_TYPE_FLOAT64] = sizeof(double),
6404
- [GGUF_TYPE_ARRAY] = 0, // undefined
6405
- };
6406
- static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
6407
-
6408
- static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
6409
- [GGUF_TYPE_UINT8] = "u8",
6410
- [GGUF_TYPE_INT8] = "i8",
6411
- [GGUF_TYPE_UINT16] = "u16",
6412
- [GGUF_TYPE_INT16] = "i16",
6413
- [GGUF_TYPE_UINT32] = "u32",
6414
- [GGUF_TYPE_INT32] = "i32",
6415
- [GGUF_TYPE_FLOAT32] = "f32",
6416
- [GGUF_TYPE_BOOL] = "bool",
6417
- [GGUF_TYPE_STRING] = "str",
6418
- [GGUF_TYPE_ARRAY] = "arr",
6419
- [GGUF_TYPE_UINT64] = "u64",
6420
- [GGUF_TYPE_INT64] = "i64",
6421
- [GGUF_TYPE_FLOAT64] = "f64",
6422
- };
6423
- static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
6424
-
6425
- union gguf_value {
6426
- uint8_t uint8;
6427
- int8_t int8;
6428
- uint16_t uint16;
6429
- int16_t int16;
6430
- uint32_t uint32;
6431
- int32_t int32;
6432
- float float32;
6433
- uint64_t uint64;
6434
- int64_t int64;
6435
- double float64;
6436
- bool bool_;
6437
-
6438
- struct gguf_str str;
6439
-
6440
- struct {
6441
- enum gguf_type type;
6442
-
6443
- uint64_t n; // GGUFv2
6444
- void * data;
6445
- } arr;
6446
- };
6447
-
6448
- struct gguf_kv {
6449
- struct gguf_str key;
6450
-
6451
- enum gguf_type type;
6452
- union gguf_value value;
6453
- };
6454
-
6455
- struct gguf_header {
6456
- char magic[4];
6457
-
6458
- uint32_t version;
6459
- uint64_t n_tensors; // GGUFv2
6460
- uint64_t n_kv; // GGUFv2
6461
- };
6462
-
6463
- struct gguf_tensor_info {
6464
- struct gguf_str name;
6465
-
6466
- uint32_t n_dims;
6467
- uint64_t ne[GGML_MAX_DIMS];
6468
-
6469
- enum ggml_type type;
6470
-
6471
- uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
6472
-
6473
- // for writing API
6474
- const void * data;
6475
- size_t size;
6476
- };
6477
-
6478
- struct gguf_context {
6479
- struct gguf_header header;
6480
-
6481
- struct gguf_kv * kv;
6482
- struct gguf_tensor_info * infos;
6483
-
6484
- size_t alignment;
6485
- size_t offset; // offset of `data` from beginning of file
6486
- size_t size; // size of `data` in bytes
6487
-
6488
- //uint8_t * padding;
6489
- void * data;
6490
- };
6491
-
6492
- static size_t gguf_type_size(enum gguf_type type) {
6493
- GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
6494
- return GGUF_TYPE_SIZE[type];
6495
- }
6496
-
6497
- static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
6498
- if (info->n_dims > GGML_MAX_DIMS) {
6499
- fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
6500
- return false;
6501
- }
6502
-
6503
- if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
6504
- fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
6505
- return false;
6506
- }
6507
-
6508
- if (strlen(info->name.data) >= GGML_MAX_NAME) {
6509
- fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
6510
- return false;
6511
- }
6512
-
6513
- for (uint32_t i = 0; i < info->n_dims; ++i) {
6514
- if (info->ne[i] <= 0) {
6515
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
6516
- return false;
6517
- }
6518
- }
6519
-
6520
- // prevent overflow for total number of elements
6521
- if (INT64_MAX/info->ne[1] <= info->ne[0]) {
6522
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
6523
- return false;
6524
- }
6525
-
6526
- if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
6527
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
6528
- return false;
6529
- }
6530
-
6531
- if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
6532
- fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
6533
- return false;
6534
- }
6535
-
6536
- return true;
6537
- }
6538
-
6539
- static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
6540
- const size_t n = fread(dst, 1, size, file);
6541
- *offset += n;
6542
- return n == size;
6543
- }
6544
-
6545
- static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
6546
- p->n = 0;
6547
- p->data = NULL;
6548
-
6549
- bool ok = true;
6550
-
6551
- ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
6552
-
6553
- // early exit if string length is invalid, prevents from integer overflow
6554
- if (p->n == SIZE_MAX) {
6555
- fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
6556
- return false;
6557
- }
6558
-
6559
- p->data = calloc(p->n + 1, 1);
6560
- if (!p->data) {
6561
- fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
6562
- return false;
6563
- }
6564
-
6565
- ok = ok && gguf_fread_el(file, p->data, p->n, offset);
6566
-
6567
- return ok;
6568
- }
6569
-
6570
- static void gguf_free_kv(struct gguf_kv * kv) {
6571
- if (kv->key.data) {
6572
- GGML_FREE(kv->key.data);
6573
- }
6574
-
6575
- if (kv->type == GGUF_TYPE_STRING) {
6576
- if (kv->value.str.data) {
6577
- GGML_FREE(kv->value.str.data);
6578
- }
6579
- }
6580
-
6581
- if (kv->type == GGUF_TYPE_ARRAY) {
6582
- if (kv->value.arr.data) {
6583
- if (kv->value.arr.type == GGUF_TYPE_STRING) {
6584
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
6585
- struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
6586
- if (str->data) {
6587
- GGML_FREE(str->data);
6588
- }
6589
- }
6590
- }
6591
- GGML_FREE(kv->value.arr.data);
6592
- }
6593
- }
6594
- }
6595
-
6596
- struct gguf_context * gguf_init_empty(void) {
6597
- struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
6598
- if (!ctx) {
6599
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
6600
- return NULL;
6601
- }
6602
-
6603
- memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
6604
- ctx->header.version = GGUF_VERSION;
6605
- ctx->header.n_tensors = 0;
6606
- ctx->header.n_kv = 0;
6607
-
6608
- ctx->kv = NULL;
6609
- ctx->infos = NULL;
6610
-
6611
- ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
6612
- ctx->offset = 0;
6613
- ctx->size = 0;
6614
-
6615
- ctx->data = NULL;
6616
-
6617
- return ctx;
6618
- }
6619
-
6620
- struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
6621
- FILE * file = ggml_fopen(fname, "rb");
6622
- if (!file) {
6623
- fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
6624
- return NULL;
6625
- }
6626
-
6627
- // offset from start of file
6628
- size_t offset = 0;
6629
-
6630
- char magic[4];
6631
-
6632
- // check the magic before making allocations
6633
- {
6634
- gguf_fread_el(file, &magic, sizeof(magic), &offset);
6635
-
6636
- for (uint32_t i = 0; i < sizeof(magic); i++) {
6637
- if (magic[i] != GGUF_MAGIC[i]) {
6638
- fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
6639
- fclose(file);
6640
- return NULL;
6641
- }
6642
- }
6643
- }
6644
-
6645
- bool ok = true;
6646
-
6647
- struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
6648
- if (!ctx) {
6649
- fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
6650
- fclose(file);
6651
- return NULL;
6652
- }
6653
-
6654
- // read the header
6655
- {
6656
- strncpy(ctx->header.magic, magic, 4);
6657
-
6658
- ctx->kv = NULL;
6659
- ctx->infos = NULL;
6660
- ctx->data = NULL;
6661
-
6662
- ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
6663
- ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
6664
- ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
6665
-
6666
- if (ctx->header.version == 1) {
6667
- fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
6668
- fclose(file);
6669
- gguf_free(ctx);
6670
- return NULL;
6671
- }
6672
-
6673
- // sanity-checks to prevent from integer/buffer overflows
6674
-
6675
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
6676
- ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
6677
- ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_kv));
6678
-
6679
- if (!ok) {
6680
- fprintf(stderr, "%s: failed to read header\n", __func__);
6681
- fclose(file);
6682
- gguf_free(ctx);
6683
- return NULL;
6684
- }
6685
- }
6686
-
6687
- // read the kv pairs
6688
- {
6689
- const uint64_t n_kv = ctx->header.n_kv;
6690
-
6691
- ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
6692
- if (!ctx->kv) {
6693
- fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
6694
- fclose(file);
6695
- gguf_free(ctx);
6696
- return NULL;
6697
- }
6698
-
6699
- for (uint64_t i = 0; i < n_kv; ++i) {
6700
- struct gguf_kv * kv = &ctx->kv[i];
6701
-
6702
- //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
6703
-
6704
- ok = ok && gguf_fread_str(file, &kv->key, &offset);
6705
- ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
6706
-
6707
- //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
6708
-
6709
- switch (kv->type) {
6710
- case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break;
6711
- case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break;
6712
- case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break;
6713
- case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break;
6714
- case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
6715
- case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
6716
- case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
6717
- case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
6718
- case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
6719
- case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
6720
- case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
6721
- case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break;
6722
- case GGUF_TYPE_ARRAY:
6723
- {
6724
- ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
6725
- ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
6726
-
6727
- switch (kv->value.arr.type) {
6728
- case GGUF_TYPE_UINT8:
6729
- case GGUF_TYPE_INT8:
6730
- case GGUF_TYPE_UINT16:
6731
- case GGUF_TYPE_INT16:
6732
- case GGUF_TYPE_UINT32:
6733
- case GGUF_TYPE_INT32:
6734
- case GGUF_TYPE_FLOAT32:
6735
- case GGUF_TYPE_UINT64:
6736
- case GGUF_TYPE_INT64:
6737
- case GGUF_TYPE_FLOAT64:
6738
- case GGUF_TYPE_BOOL:
6739
- {
6740
- // prevent from integer overflow in the malloc below
6741
- if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
6742
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
6743
- fclose(file);
6744
- gguf_free(ctx);
6745
- return NULL;
6746
- }
6747
-
6748
- kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
6749
- if (!kv->value.arr.data) {
6750
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
6751
- fclose(file);
6752
- gguf_free(ctx);
6753
- return NULL;
6754
- }
6755
-
6756
- ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
6757
- } break;
6758
- case GGUF_TYPE_STRING:
6759
- {
6760
- // prevent from integer overflow in the malloc below
6761
- if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
6762
- fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
6763
- fclose(file);
6764
- gguf_free(ctx);
6765
- return NULL;
6766
- }
6767
-
6768
- kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
6769
- if (!kv->value.arr.data) {
6770
- fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
6771
- fclose(file);
6772
- gguf_free(ctx);
6773
- return NULL;
6774
- }
6775
-
6776
- for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
6777
- ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
6778
- }
6779
- } break;
6780
- case GGUF_TYPE_ARRAY:
6781
- default:
6782
- {
6783
- fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
6784
- ok = false;
6785
- } break;
6786
- }
6787
- } break;
6788
- default:
6789
- {
6790
- fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
6791
- ok = false;
6792
- } break;
6793
- }
6794
-
6795
- if (!ok) {
6796
- break;
6797
- }
6798
- }
6799
-
6800
- if (!ok) {
6801
- fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
6802
- fclose(file);
6803
- gguf_free(ctx);
6804
- return NULL;
6805
- }
6806
- }
6807
-
6808
- // read the tensor infos
6809
- if (ctx->header.n_tensors > 0) {
6810
- ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
6811
- if (!ctx->infos) {
6812
- fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
6813
- fclose(file);
6814
- gguf_free(ctx);
6815
- return NULL;
6816
- }
6817
-
6818
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6819
- struct gguf_tensor_info * info = &ctx->infos[i];
6820
-
6821
- for (int j = 0; j < GGML_MAX_DIMS; ++j) {
6822
- info->ne[j] = 1;
6823
- }
6824
-
6825
- ok = ok && gguf_fread_str(file, &info->name, &offset);
6826
- ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
6827
-
6828
- ok = ok && (info->n_dims <= GGML_MAX_DIMS);
6829
-
6830
- for (uint32_t j = 0; j < info->n_dims; ++j) {
6831
- ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
6832
- }
6833
-
6834
- ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
6835
- ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
6836
-
6837
- ok = ok && gguf_tensor_info_sanitize(info);
6838
-
6839
- // make sure there is no duplicated tensor names
6840
- for (uint64_t j = 0; j < i && ok; ++j) {
6841
- if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
6842
- fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
6843
- ok = false;
6844
- }
6845
- }
6846
-
6847
- if (!ok) {
6848
- fprintf(stderr, "%s: failed to read tensor info\n", __func__);
6849
- fclose(file);
6850
- gguf_free(ctx);
6851
- return NULL;
6852
- }
6853
- }
6854
- }
6855
-
6856
- ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
6857
-
6858
- int alignment_idx = gguf_find_key(ctx, "general.alignment");
6859
- if (alignment_idx != -1) {
6860
- ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
6861
- }
6862
-
6863
- // we require the data section to be aligned, so take into account any padding
6864
- {
6865
- const size_t offset_pad = offset % ctx->alignment;
6866
-
6867
- if (offset_pad != 0) {
6868
- offset += ctx->alignment - offset_pad;
6869
- fseek(file, offset, SEEK_SET);
6870
- }
6871
- }
6872
-
6873
- // store the current file offset - this is where the data section starts
6874
- ctx->offset = offset;
6875
-
6876
- // compute the total size of the data section, taking into account the alignment
6877
- {
6878
- ctx->size = 0;
6879
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6880
- struct gguf_tensor_info * info = &ctx->infos[i];
6881
-
6882
- const int64_t ne =
6883
- (int64_t) info->ne[0] *
6884
- (int64_t) info->ne[1] *
6885
- (int64_t) info->ne[2] *
6886
- (int64_t) info->ne[3];
6887
-
6888
- if (ggml_blck_size(info->type) == 0 ) {
6889
- // this tensor type support have been removed:
6890
- fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
6891
- __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
6892
- fclose(file);
6893
- gguf_free(ctx);
6894
- return NULL;
6895
- }
6896
-
6897
- if (ne % ggml_blck_size(info->type) != 0) {
6898
- fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
6899
- __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
6900
- fclose(file);
6901
- gguf_free(ctx);
6902
- return NULL;
6903
- }
6904
-
6905
- const size_t size_cur = ggml_row_size(info->type, ne);
6906
-
6907
- ctx->size += GGML_PAD(size_cur, ctx->alignment);
6908
- }
6909
- }
6910
-
6911
- // load the tensor data only if requested
6912
- if (params.ctx != NULL) {
6913
- // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
6914
- // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
6915
- // the ggml_tensor structs to the appropriate locations in the binary blob
6916
-
6917
- // compute the exact size needed for the new ggml_context
6918
- const size_t mem_size =
6919
- params.no_alloc ?
6920
- (ctx->header.n_tensors )*ggml_tensor_overhead() :
6921
- (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
6922
-
6923
- struct ggml_init_params pdata = {
6924
- .mem_size = mem_size,
6925
- .mem_buffer = NULL,
6926
- .no_alloc = params.no_alloc,
6927
- };
6928
-
6929
- *params.ctx = ggml_init(pdata);
6930
- if (*params.ctx == NULL) {
6931
- fprintf(stderr, "%s: failed to initialize context\n", __func__);
6932
- fclose(file);
6933
- gguf_free(ctx);
6934
- return NULL;
6935
- }
6936
-
6937
- struct ggml_context * ctx_data = *params.ctx;
6938
-
6939
- struct ggml_tensor * data = NULL;
6940
-
6941
- if (!params.no_alloc) {
6942
- data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
6943
-
6944
- ok = ok && data != NULL;
6945
-
6946
- // read the binary blob with the tensor data
6947
- ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
6948
-
6949
- if (!ok) {
6950
- fprintf(stderr, "%s: failed to read tensor data\n", __func__);
6951
- fclose(file);
6952
- ggml_free(ctx_data);
6953
- gguf_free(ctx);
6954
- return NULL;
6955
- }
6956
-
6957
- ctx->data = data->data;
6958
- }
6959
-
6960
- ggml_set_no_alloc(ctx_data, true);
6961
-
6962
- // create the tensors
6963
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
6964
- const int64_t ne[GGML_MAX_DIMS] = {
6965
- ctx->infos[i].ne[0],
6966
- ctx->infos[i].ne[1],
6967
- ctx->infos[i].ne[2],
6968
- ctx->infos[i].ne[3],
6969
- };
6970
-
6971
- struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
6972
-
6973
- ok = ok && cur != NULL;
6974
-
6975
- if (!ok) {
6976
- break;
6977
- }
6978
-
6979
- ggml_set_name(cur, ctx->infos[i].name.data);
6980
-
6981
- // point the data member to the appropriate location in the binary blob using the tensor infos
6982
- if (!params.no_alloc) {
6983
- //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
6984
- cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
6985
- }
6986
- }
6987
-
6988
- if (!ok) {
6989
- fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
6990
- fclose(file);
6991
- ggml_free(ctx_data);
6992
- gguf_free(ctx);
6993
- return NULL;
6994
- }
6995
-
6996
- ggml_set_no_alloc(ctx_data, params.no_alloc);
6997
- }
6998
-
6999
- fclose(file);
7000
-
7001
- return ctx;
7002
- }
7003
-
7004
- void gguf_free(struct gguf_context * ctx) {
7005
- if (ctx == NULL) {
7006
- return;
7007
- }
7008
-
7009
- if (ctx->kv) {
7010
- // free string memory - not great..
7011
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
7012
- gguf_free_kv(&ctx->kv[i]);
7013
- }
7014
-
7015
- GGML_FREE(ctx->kv);
7016
- }
7017
-
7018
- if (ctx->infos) {
7019
- for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
7020
- struct gguf_tensor_info * info = &ctx->infos[i];
7021
-
7022
- if (info->name.data) {
7023
- GGML_FREE(info->name.data);
7024
- }
7025
- }
7026
-
7027
- GGML_FREE(ctx->infos);
7028
- }
7029
-
7030
- GGML_FREE(ctx);
7031
- }
7032
-
7033
- const char * gguf_type_name(enum gguf_type type) {
7034
- return GGUF_TYPE_NAME[type];
7035
- }
7036
-
7037
- int gguf_get_version(const struct gguf_context * ctx) {
7038
- return ctx->header.version;
7039
- }
7040
-
7041
- size_t gguf_get_alignment(const struct gguf_context * ctx) {
7042
- return ctx->alignment;
7043
- }
7044
-
7045
- size_t gguf_get_data_offset(const struct gguf_context * ctx) {
7046
- return ctx->offset;
7047
- }
7048
-
7049
- void * gguf_get_data(const struct gguf_context * ctx) {
7050
- return ctx->data;
7051
- }
7052
-
7053
- int gguf_get_n_kv(const struct gguf_context * ctx) {
7054
- return ctx->header.n_kv;
7055
- }
7056
-
7057
- int gguf_find_key(const struct gguf_context * ctx, const char * key) {
7058
- // return -1 if key not found
7059
- int keyfound = -1;
7060
-
7061
- const int n_kv = gguf_get_n_kv(ctx);
7062
-
7063
- for (int i = 0; i < n_kv; ++i) {
7064
- if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
7065
- keyfound = i;
7066
- break;
7067
- }
7068
- }
7069
-
7070
- return keyfound;
7071
- }
7072
-
7073
- const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
7074
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7075
- return ctx->kv[key_id].key.data;
7076
- }
7077
-
7078
- enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
7079
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7080
- return ctx->kv[key_id].type;
7081
- }
7082
-
7083
- enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
7084
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7085
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
7086
- return ctx->kv[key_id].value.arr.type;
7087
- }
7088
-
7089
- const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
7090
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7091
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
7092
- return ctx->kv[key_id].value.arr.data;
7093
- }
7094
-
7095
- const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
7096
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7097
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
7098
- struct gguf_kv * kv = &ctx->kv[key_id];
7099
- struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
7100
- return str->data;
7101
- }
7102
-
7103
- int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
7104
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7105
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
7106
- return ctx->kv[key_id].value.arr.n;
7107
- }
7108
-
7109
- uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
7110
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7111
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
7112
- return ctx->kv[key_id].value.uint8;
7113
- }
7114
-
7115
- int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
7116
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7117
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
7118
- return ctx->kv[key_id].value.int8;
7119
- }
7120
-
7121
- uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
7122
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7123
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
7124
- return ctx->kv[key_id].value.uint16;
7125
- }
7126
-
7127
- int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
7128
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7129
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
7130
- return ctx->kv[key_id].value.int16;
7131
- }
7132
-
7133
- uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
7134
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7135
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
7136
- return ctx->kv[key_id].value.uint32;
7137
- }
7138
-
7139
- int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
7140
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7141
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
7142
- return ctx->kv[key_id].value.int32;
7143
- }
7144
-
7145
- float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
7146
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7147
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
7148
- return ctx->kv[key_id].value.float32;
7149
- }
7150
-
7151
- uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
7152
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7153
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
7154
- return ctx->kv[key_id].value.uint64;
7155
- }
7156
-
7157
- int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
7158
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7159
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
7160
- return ctx->kv[key_id].value.int64;
7161
- }
7162
-
7163
- double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
7164
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7165
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
7166
- return ctx->kv[key_id].value.float64;
7167
- }
7168
-
7169
- bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
7170
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7171
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
7172
- return ctx->kv[key_id].value.bool_;
7173
- }
7174
-
7175
- const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
7176
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7177
- GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
7178
- return ctx->kv[key_id].value.str.data;
7179
- }
7180
-
7181
- const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
7182
- GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
7183
- GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
7184
- GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
7185
- return &ctx->kv[key_id].value;
7186
- }
7187
-
7188
- int gguf_get_n_tensors(const struct gguf_context * ctx) {
7189
- return ctx->header.n_tensors;
7190
- }
7191
-
7192
- int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
7193
- // return -1 if tensor not found
7194
- int tensorfound = -1;
7195
-
7196
- const int n_tensors = gguf_get_n_tensors(ctx);
7197
-
7198
- for (int i = 0; i < n_tensors; ++i) {
7199
- if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
7200
- tensorfound = i;
7201
- break;
7202
- }
7203
- }
7204
-
7205
- return tensorfound;
7206
- }
7207
-
7208
- size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
7209
- return ctx->infos[i].offset;
7210
- }
7211
-
7212
- char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
7213
- return ctx->infos[i].name.data;
7214
- }
7215
-
7216
- enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
7217
- return ctx->infos[i].type;
7218
- }
7219
-
7220
- // returns the index
7221
- static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
7222
- const int idx = gguf_find_key(ctx, key);
7223
- if (idx >= 0) {
7224
- return idx;
7225
- }
7226
-
7227
- const int n_kv = gguf_get_n_kv(ctx);
7228
-
7229
- ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
7230
- ctx->kv[n_kv].key.n = strlen(key);
7231
- ctx->kv[n_kv].key.data = strdup(key);
7232
- ctx->header.n_kv++;
7233
-
7234
- return n_kv;
7235
- }
7236
-
7237
- void gguf_remove_key(struct gguf_context * ctx, const char * key) {
7238
- const int idx = gguf_find_key(ctx, key);
7239
- if (idx >= 0) {
7240
- const int n_kv = gguf_get_n_kv(ctx);
7241
- gguf_free_kv(&ctx->kv[idx]);
7242
- for (int i = idx; i < n_kv-1; ++i) {
7243
- ctx->kv[i] = ctx->kv[i+1];
7244
- }
7245
- ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
7246
- ctx->header.n_kv--;
7247
- }
7248
- }
7249
-
7250
- void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
7251
- const int idx = gguf_get_or_add_key(ctx, key);
7252
-
7253
- ctx->kv[idx].type = GGUF_TYPE_UINT8;
7254
- ctx->kv[idx].value.uint8 = val;
7255
- }
7256
-
7257
- void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
7258
- const int idx = gguf_get_or_add_key(ctx, key);
7259
-
7260
- ctx->kv[idx].type = GGUF_TYPE_INT8;
7261
- ctx->kv[idx].value.int8 = val;
7262
- }
7263
-
7264
- void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
7265
- const int idx = gguf_get_or_add_key(ctx, key);
7266
-
7267
- ctx->kv[idx].type = GGUF_TYPE_UINT16;
7268
- ctx->kv[idx].value.uint16 = val;
7269
- }
7270
-
7271
- void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
7272
- const int idx = gguf_get_or_add_key(ctx, key);
7273
-
7274
- ctx->kv[idx].type = GGUF_TYPE_INT16;
7275
- ctx->kv[idx].value.int16 = val;
7276
- }
7277
-
7278
- void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
7279
- const int idx = gguf_get_or_add_key(ctx, key);
7280
-
7281
- ctx->kv[idx].type = GGUF_TYPE_UINT32;
7282
- ctx->kv[idx].value.uint32 = val;
7283
- }
7284
-
7285
- void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
7286
- const int idx = gguf_get_or_add_key(ctx, key);
7287
-
7288
- ctx->kv[idx].type = GGUF_TYPE_INT32;
7289
- ctx->kv[idx].value.int32 = val;
7290
- }
7291
-
7292
- void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
7293
- const int idx = gguf_get_or_add_key(ctx, key);
7294
-
7295
- ctx->kv[idx].type = GGUF_TYPE_FLOAT32;
7296
- ctx->kv[idx].value.float32 = val;
7297
- }
7298
-
7299
- void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
7300
- const int idx = gguf_get_or_add_key(ctx, key);
7301
-
7302
- ctx->kv[idx].type = GGUF_TYPE_UINT64;
7303
- ctx->kv[idx].value.uint64 = val;
7304
- }
7305
-
7306
- void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
7307
- const int idx = gguf_get_or_add_key(ctx, key);
7308
-
7309
- ctx->kv[idx].type = GGUF_TYPE_INT64;
7310
- ctx->kv[idx].value.int64 = val;
7311
- }
7312
-
7313
- void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
7314
- const int idx = gguf_get_or_add_key(ctx, key);
7315
-
7316
- ctx->kv[idx].type = GGUF_TYPE_FLOAT64;
7317
- ctx->kv[idx].value.float64 = val;
7318
- }
7319
-
7320
- void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
7321
- const int idx = gguf_get_or_add_key(ctx, key);
7322
-
7323
- ctx->kv[idx].type = GGUF_TYPE_BOOL;
7324
- ctx->kv[idx].value.bool_ = val;
7325
- }
7326
-
7327
- void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
7328
- const int idx = gguf_get_or_add_key(ctx, key);
7329
-
7330
- ctx->kv[idx].type = GGUF_TYPE_STRING;
7331
- ctx->kv[idx].value.str.n = strlen(val);
7332
- ctx->kv[idx].value.str.data = strdup(val);
7333
- }
7334
-
7335
- void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
7336
- const int idx = gguf_get_or_add_key(ctx, key);
7337
-
7338
- ctx->kv[idx].type = GGUF_TYPE_ARRAY;
7339
- ctx->kv[idx].value.arr.type = type;
7340
- ctx->kv[idx].value.arr.n = n;
7341
- ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
7342
- memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
7343
- }
7344
-
7345
- void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
7346
- const int idx = gguf_get_or_add_key(ctx, key);
7347
-
7348
- ctx->kv[idx].type = GGUF_TYPE_ARRAY;
7349
- ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
7350
- ctx->kv[idx].value.arr.n = n;
7351
- ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
7352
- for (int i = 0; i < n; i++) {
7353
- struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
7354
- str->n = strlen(data[i]);
7355
- str->data = strdup(data[i]);
7356
- }
7357
- }
7358
-
7359
- // set or add KV pairs from another context
7360
- void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
7361
- for (uint32_t i = 0; i < src->header.n_kv; i++) {
7362
- switch (src->kv[i].type) {
7363
- case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break;
7364
- case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break;
7365
- case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break;
7366
- case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break;
7367
- case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
7368
- case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
7369
- case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
7370
- case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
7371
- case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
7372
- case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
7373
- case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
7374
- case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
7375
- case GGUF_TYPE_ARRAY:
7376
- {
7377
- if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
7378
- const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
7379
- for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
7380
- data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
7381
- }
7382
- gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
7383
- GGML_FREE((void *)data);
7384
- } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
7385
- GGML_ABORT("nested arrays not supported");
7386
- } else {
7387
- gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
7388
- }
7389
- } break;
7390
- default: GGML_ABORT("invalid type");
7391
- }
7392
- }
7393
- }
7394
-
7395
- void gguf_add_tensor(
7396
- struct gguf_context * ctx,
7397
- const struct ggml_tensor * tensor) {
7398
- GGML_ASSERT(tensor);
7399
- if (gguf_find_tensor(ctx, tensor->name) != -1) {
7400
- GGML_ABORT("duplicated tensor name");
7401
- }
7402
-
7403
- const int idx = ctx->header.n_tensors;
7404
- ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
7405
-
7406
- ctx->infos[idx].name.n = strlen(tensor->name);
7407
- ctx->infos[idx].name.data = strdup(tensor->name);
7408
-
7409
- for (int i = 0; i < GGML_MAX_DIMS; ++i) {
7410
- ctx->infos[idx].ne[i] = 1;
7411
- }
7412
-
7413
- ctx->infos[idx].n_dims = ggml_n_dims(tensor);
7414
- for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
7415
- ctx->infos[idx].ne[i] = tensor->ne[i];
7416
- }
7417
-
7418
- ctx->infos[idx].type = tensor->type;
7419
- ctx->infos[idx].offset = 0;
7420
- ctx->infos[idx].data = tensor->data;
7421
- ctx->infos[idx].size = ggml_nbytes(tensor);
7422
-
7423
- if (ctx->header.n_tensors > 0) {
7424
- ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
7425
- }
7426
-
7427
- ctx->header.n_tensors++;
7428
- }
7429
-
7430
- void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
7431
- const int idx = gguf_find_tensor(ctx, name);
7432
- if (idx < 0) {
7433
- GGML_ABORT("tensor not found");
7434
- }
7435
-
7436
- ctx->infos[idx].type = type;
7437
- }
7438
-
7439
- void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
7440
- const int idx = gguf_find_tensor(ctx, name);
7441
- if (idx < 0) {
7442
- GGML_ABORT("tensor not found");
7443
- }
7444
-
7445
- ctx->infos[idx].data = data;
7446
- ctx->infos[idx].size = size;
7447
-
7448
- // update offsets
7449
- for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
7450
- ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
7451
- }
7452
- }
7453
-
7454
- //static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
7455
- // fwrite(&val->n, sizeof(val->n), 1, file);
7456
- // fwrite(val->data, sizeof(char), val->n, file);
7457
- //}
7458
- //
7459
- //static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
7460
- // fwrite(val, sizeof(char), size, file);
7461
- //}
7462
-
7463
- struct gguf_buf {
7464
- void * data;
7465
- size_t size;
7466
- size_t offset;
7467
- };
7468
-
7469
- static struct gguf_buf gguf_buf_init(size_t size) {
7470
- struct gguf_buf buf = {
7471
- /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
7472
- /*buf.size =*/ size,
7473
- /*buf.offset =*/ 0,
7474
- };
7475
-
7476
- return buf;
7477
- }
7478
-
7479
- static void gguf_buf_free(struct gguf_buf buf) {
7480
- if (buf.data) {
7481
- GGML_FREE(buf.data);
7482
- }
7483
- }
7484
-
7485
- static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
7486
- if (buf->offset + size > buf->size) {
7487
- buf->size = 1.5*(buf->offset + size);
7488
- if (buf->data) {
7489
- buf->data = realloc(buf->data, buf->size);
7490
- }
7491
- }
7492
- }
7493
-
7494
- static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
7495
- gguf_buf_grow(buf, sizeof(val->n) + val->n);
7496
-
7497
- if (buf->data) {
7498
- memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
7499
- }
7500
- buf->offset += sizeof(val->n);
7501
-
7502
- if (buf->data) {
7503
- memcpy((char *) buf->data + buf->offset, val->data, val->n);
7504
- }
7505
- buf->offset += val->n;
7506
- }
7507
-
7508
- static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
7509
- gguf_buf_grow(buf, el_size);
7510
-
7511
- if (buf->data) {
7512
- memcpy((char *) buf->data + buf->offset, val, el_size);
7513
- }
7514
- buf->offset += el_size;
7515
- }
7516
-
7517
- static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
7518
- // write header
7519
- gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
7520
- gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
7521
- gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
7522
- gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv));
7523
-
7524
- // write key-value pairs
7525
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
7526
- struct gguf_kv * kv = &ctx->kv[i];
7527
-
7528
- gguf_bwrite_str(buf, &kv->key);
7529
- gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
7530
-
7531
- switch (kv->type) {
7532
- case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break;
7533
- case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break;
7534
- case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break;
7535
- case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break;
7536
- case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
7537
- case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
7538
- case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
7539
- case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
7540
- case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
7541
- case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
7542
- case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
7543
- case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break;
7544
- case GGUF_TYPE_ARRAY:
7545
- {
7546
- gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
7547
- gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) );
7548
-
7549
- switch (kv->value.arr.type) {
7550
- case GGUF_TYPE_UINT8:
7551
- case GGUF_TYPE_INT8:
7552
- case GGUF_TYPE_UINT16:
7553
- case GGUF_TYPE_INT16:
7554
- case GGUF_TYPE_UINT32:
7555
- case GGUF_TYPE_INT32:
7556
- case GGUF_TYPE_FLOAT32:
7557
- case GGUF_TYPE_UINT64:
7558
- case GGUF_TYPE_INT64:
7559
- case GGUF_TYPE_FLOAT64:
7560
- case GGUF_TYPE_BOOL:
7561
- {
7562
- gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
7563
- } break;
7564
- case GGUF_TYPE_STRING:
7565
- {
7566
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
7567
- gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
7568
- }
7569
- } break;
7570
- case GGUF_TYPE_ARRAY:
7571
- default: GGML_ABORT("invalid type");
7572
- }
7573
- } break;
7574
- default: GGML_ABORT("invalid type");
7575
- }
7576
- }
7577
-
7578
- // write tensor infos
7579
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7580
- struct gguf_tensor_info * info = &ctx->infos[i];
7581
-
7582
- gguf_bwrite_str(buf, &info->name);
7583
- gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
7584
- for (uint32_t j = 0; j < info->n_dims; ++j) {
7585
- gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
7586
- }
7587
- gguf_bwrite_el(buf, &info->type, sizeof(info->type));
7588
- gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
7589
- }
7590
-
7591
- // we require the data section to be aligned, so take into account any padding
7592
- {
7593
- const size_t offset = buf->offset;
7594
- const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
7595
-
7596
- if (offset_pad != offset) {
7597
- uint8_t pad = 0;
7598
- for (size_t i = 0; i < offset_pad - offset; ++i) {
7599
- gguf_bwrite_el(buf, &pad, sizeof(pad));
7600
- }
7601
- }
7602
- }
7603
-
7604
- if (only_meta) {
7605
- return;
7606
- }
7607
-
7608
- size_t offset = 0;
7609
-
7610
- // write tensor data
7611
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
7612
- struct gguf_tensor_info * info = &ctx->infos[i];
7613
-
7614
- const size_t size = info->size;
7615
- const size_t size_pad = GGML_PAD(size, ctx->alignment);
7616
-
7617
- gguf_bwrite_el(buf, info->data, size);
7618
-
7619
- if (size_pad != size) {
7620
- uint8_t pad = 0;
7621
- for (size_t j = 0; j < size_pad - size; ++j) {
7622
- gguf_bwrite_el(buf, &pad, sizeof(pad));
7623
- }
7624
- }
7625
-
7626
- GGML_ASSERT(offset == info->offset);
7627
-
7628
- offset += size_pad;
7629
- }
7630
- }
7631
-
7632
- void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
7633
- FILE * file = ggml_fopen(fname, "wb");
7634
- if (!file) {
7635
- GGML_ABORT("failed to open file for writing");
7636
- }
7637
-
7638
- struct gguf_buf buf = gguf_buf_init(16*1024);
7639
-
7640
- gguf_write_to_buf(ctx, &buf, only_meta);
7641
-
7642
- fwrite(buf.data, 1, buf.offset, file);
7643
-
7644
- gguf_buf_free(buf);
7645
-
7646
- fclose(file);
7647
- }
7648
-
7649
- size_t gguf_get_meta_size(const struct gguf_context * ctx) {
7650
- // no allocs - only compute size
7651
- struct gguf_buf buf = gguf_buf_init(0);
7652
-
7653
- gguf_write_to_buf(ctx, &buf, true);
7654
-
7655
- return buf.offset;
7656
- }
7657
-
7658
- void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
7659
- struct gguf_buf buf = gguf_buf_init(16*1024);
7660
-
7661
- gguf_write_to_buf(ctx, &buf, true);
7662
-
7663
- memcpy(data, buf.data, buf.offset);
7664
-
7665
- gguf_buf_free(buf);
7666
- }
7667
-
7668
6524
  void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
7669
6525
  g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
7670
6526
  g_logger_state.log_callback_user_data = user_data;