whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -1,6 +1,7 @@
1
1
  #include "ggml-rpc.h"
2
2
  #include "ggml-impl.h"
3
3
  #include "ggml-backend-impl.h"
4
+ #include "ggml-cpp.h"
4
5
 
5
6
  #include <cinttypes>
6
7
  #include <string>
@@ -26,15 +27,10 @@
26
27
  # include <unistd.h>
27
28
  #endif
28
29
  #include <cstring>
30
+ #include <fstream>
31
+ #include <filesystem>
29
32
 
30
- #define UNUSED GGML_UNUSED
31
-
32
- #define GGML_DEBUG 0
33
- #if (GGML_DEBUG >= 1)
34
- #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
35
- #else
36
- #define GGML_PRINT_DEBUG(...)
37
- #endif
33
+ namespace fs = std::filesystem;
38
34
 
39
35
  #ifdef _WIN32
40
36
  typedef SOCKET sockfd_t;
@@ -89,13 +85,38 @@ enum rpc_cmd {
89
85
  RPC_CMD_FREE_BUFFER,
90
86
  RPC_CMD_BUFFER_CLEAR,
91
87
  RPC_CMD_SET_TENSOR,
88
+ RPC_CMD_SET_TENSOR_HASH,
92
89
  RPC_CMD_GET_TENSOR,
93
90
  RPC_CMD_COPY_TENSOR,
94
91
  RPC_CMD_GRAPH_COMPUTE,
95
92
  RPC_CMD_GET_DEVICE_MEMORY,
93
+ RPC_CMD_INIT_TENSOR,
94
+ RPC_CMD_GET_ALLOC_SIZE,
95
+ RPC_CMD_HELLO,
96
96
  RPC_CMD_COUNT,
97
97
  };
98
98
 
99
+ // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
100
+ const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
101
+
102
+ struct rpc_msg_hello_rsp {
103
+ uint8_t major;
104
+ uint8_t minor;
105
+ uint8_t patch;
106
+ };
107
+
108
+ struct rpc_msg_get_alloc_size_req {
109
+ rpc_tensor tensor;
110
+ };
111
+
112
+ struct rpc_msg_get_alloc_size_rsp {
113
+ uint64_t alloc_size;
114
+ };
115
+
116
+ struct rpc_msg_init_tensor_req {
117
+ rpc_tensor tensor;
118
+ };
119
+
99
120
  struct rpc_msg_alloc_buffer_req {
100
121
  uint64_t size;
101
122
  };
@@ -130,6 +151,16 @@ struct rpc_msg_buffer_clear_req {
130
151
  uint8_t value;
131
152
  };
132
153
 
154
+ struct rpc_msg_set_tensor_hash_req {
155
+ rpc_tensor tensor;
156
+ uint64_t offset;
157
+ uint64_t hash;
158
+ };
159
+
160
+ struct rpc_msg_set_tensor_hash_rsp {
161
+ uint8_t result;
162
+ };
163
+
133
164
  struct rpc_msg_get_tensor_req {
134
165
  rpc_tensor tensor;
135
166
  uint64_t offset;
@@ -176,12 +207,24 @@ struct ggml_backend_rpc_context {
176
207
 
177
208
  struct ggml_backend_rpc_buffer_context {
178
209
  std::shared_ptr<socket_t> sock;
179
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
210
+ void * base_ptr;
180
211
  uint64_t remote_ptr;
181
212
  };
182
213
 
183
214
  // RPC helper functions
184
215
 
216
+ // Computes FNV-1a hash of the data
217
+ static uint64_t fnv_hash(const uint8_t * data, size_t len) {
218
+ const uint64_t fnv_prime = 0x100000001b3ULL;
219
+ uint64_t hash = 0xcbf29ce484222325ULL;
220
+
221
+ for (size_t i = 0; i < len; ++i) {
222
+ hash ^= data[i];
223
+ hash *= fnv_prime;
224
+ }
225
+ return hash;
226
+ }
227
+
185
228
  static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
186
229
  #ifdef _WIN32
187
230
  if (fd == INVALID_SOCKET) {
@@ -341,8 +384,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
341
384
  }
342
385
 
343
386
  // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
344
- // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
345
- static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
387
+ // No response
388
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
346
389
  uint8_t cmd_byte = cmd;
347
390
  if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
348
391
  return false;
@@ -353,6 +396,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
353
396
  if (!send_data(sock->fd, input, input_size)) {
354
397
  return false;
355
398
  }
399
+ return true;
400
+ }
401
+
402
+ // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
403
+ // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
404
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
405
+ if (!send_rpc_cmd(sock, cmd, input, input_size)) {
406
+ return false;
407
+ }
356
408
  // TODO: currently the output_size is always known, do we need support for commands with variable output size?
357
409
  // even if we do, we can skip sending output_size from the server for commands with known output size
358
410
  uint64_t out_size;
@@ -370,6 +422,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
370
422
 
371
423
  // RPC client-side implementation
372
424
 
425
+ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
426
+ rpc_msg_hello_rsp response;
427
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
428
+ GGML_ASSERT(status);
429
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
430
+ fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
431
+ return false;
432
+ }
433
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
434
+ fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
435
+ }
436
+ return true;
437
+ }
438
+
373
439
  static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
374
440
  static std::mutex mutex;
375
441
  std::lock_guard<std::mutex> lock(mutex);
@@ -397,12 +463,15 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
397
463
  initialized = true;
398
464
  }
399
465
  #else
400
- UNUSED(initialized);
466
+ GGML_UNUSED(initialized);
401
467
  #endif
402
468
  auto sock = socket_connect(host.c_str(), port);
403
469
  if (sock == nullptr) {
404
470
  return nullptr;
405
471
  }
472
+ if (!check_server_version(sock)) {
473
+ return nullptr;
474
+ }
406
475
  GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
407
476
  sockets[endpoint] = sock;
408
477
  return sock;
@@ -418,16 +487,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
418
487
 
419
488
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
420
489
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
421
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
422
- return ctx->base_cache[buffer];
490
+ if (ctx->base_ptr != nullptr) {
491
+ return ctx->base_ptr;
423
492
  }
424
493
  rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
425
494
  rpc_msg_buffer_get_base_rsp response;
426
495
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
427
496
  GGML_ASSERT(status);
428
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
429
- ctx->base_cache[buffer] = base_ptr;
430
- return base_ptr;
497
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
498
+ return ctx->base_ptr;
431
499
  }
432
500
 
433
501
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -456,28 +524,55 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
456
524
  result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
457
525
  result.view_offs = tensor->view_offs;
458
526
  result.data = reinterpret_cast<uint64_t>(tensor->data);
527
+
528
+ // Avoid sending uninitialized data over the wire
529
+ memset(result.name, 0, sizeof(result.name));
530
+ memset(result.padding, 0, sizeof(result.padding));
531
+
459
532
  snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
460
533
  return result;
461
534
  }
462
535
 
463
- static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
464
- UNUSED(buffer);
465
- if (ggml_is_quantized(tensor->type)) {
466
- // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
467
- GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
536
+ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
537
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
538
+
539
+ // CUDA backend on the server pads everything to 512 due to CUDA limitations.
540
+ // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
541
+ // In particular, only quantized tensors need padding
542
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
543
+ rpc_msg_init_tensor_req request;
544
+
545
+ request.tensor = serialize_tensor(tensor);
546
+
547
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
548
+ GGML_ASSERT(status);
468
549
  }
550
+ return GGML_STATUS_SUCCESS;
469
551
  }
470
552
 
471
553
  static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
472
554
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
473
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
555
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
556
+ if (size > HASH_THRESHOLD) {
557
+ rpc_msg_set_tensor_hash_req request;
558
+ request.tensor = rpc_tensor;
559
+ request.offset = offset;
560
+ request.hash = fnv_hash((const uint8_t*)data, size);
561
+ rpc_msg_set_tensor_hash_rsp response;
562
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
563
+ GGML_ASSERT(status);
564
+ if (response.result) {
565
+ // the server has the same data, no need to send it
566
+ return;
567
+ }
568
+ }
569
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
474
570
  size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
475
571
  std::vector<uint8_t> input(input_size, 0);
476
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
477
572
  memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
478
573
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
479
574
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
480
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
575
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
481
576
  GGML_ASSERT(status);
482
577
  }
483
578
 
@@ -544,7 +639,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
544
639
  if (response.remote_ptr != 0) {
545
640
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
546
641
  ggml_backend_rpc_buffer_interface,
547
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
642
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
548
643
  response.remote_size);
549
644
  return buffer;
550
645
  } else {
@@ -577,8 +672,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
577
672
  }
578
673
 
579
674
  static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
580
- UNUSED(buft);
581
- return ggml_nbytes(tensor);
675
+ // See comments in init_tensor.
676
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
677
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
678
+ auto sock = get_socket(buft_ctx->endpoint);
679
+
680
+ rpc_msg_get_alloc_size_req request;
681
+
682
+ request.tensor = serialize_tensor(tensor);
683
+
684
+ rpc_msg_get_alloc_size_rsp response;
685
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
686
+ GGML_ASSERT(status);
687
+
688
+ return response.alloc_size;
689
+ } else {
690
+ return ggml_nbytes(tensor);
691
+ }
582
692
  }
583
693
 
584
694
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -603,7 +713,7 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
603
713
  }
604
714
 
605
715
  static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
606
- UNUSED(backend);
716
+ GGML_UNUSED(backend);
607
717
  // this is no-op because we don't have any async operations
608
718
  }
609
719
 
@@ -744,9 +854,12 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
744
854
 
745
855
  class rpc_server {
746
856
  public:
747
- rpc_server(ggml_backend_t backend) : backend(backend) {}
857
+ rpc_server(ggml_backend_t backend, const char * cache_dir)
858
+ : backend(backend), cache_dir(cache_dir) {
859
+ }
748
860
  ~rpc_server();
749
861
 
862
+ void hello(rpc_msg_hello_rsp & response);
750
863
  void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
751
864
  void get_alignment(rpc_msg_get_alignment_rsp & response);
752
865
  void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -754,11 +867,15 @@ public:
754
867
  bool free_buffer(const rpc_msg_free_buffer_req & request);
755
868
  bool buffer_clear(const rpc_msg_buffer_clear_req & request);
756
869
  bool set_tensor(const std::vector<uint8_t> & input);
870
+ bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
757
871
  bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758
872
  bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759
873
  bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
874
+ bool init_tensor(const rpc_msg_init_tensor_req & request);
875
+ bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
760
876
 
761
877
  private:
878
+ bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
762
879
  ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
763
880
  ggml_tensor * create_node(uint64_t id,
764
881
  struct ggml_context * ctx,
@@ -767,9 +884,47 @@ private:
767
884
 
768
885
 
769
886
  ggml_backend_t backend;
887
+ const char * cache_dir;
770
888
  std::unordered_set<ggml_backend_buffer_t> buffers;
771
889
  };
772
890
 
891
+ void rpc_server::hello(rpc_msg_hello_rsp & response) {
892
+ response.major = RPC_PROTO_MAJOR_VERSION;
893
+ response.minor = RPC_PROTO_MINOR_VERSION;
894
+ response.patch = RPC_PROTO_PATCH_VERSION;
895
+ GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
896
+ }
897
+
898
+ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
899
+ ggml_backend_buffer_type_t buft;
900
+ struct ggml_init_params params {
901
+ /*.mem_size =*/ ggml_tensor_overhead(),
902
+ /*.mem_buffer =*/ NULL,
903
+ /*.no_alloc =*/ true,
904
+ };
905
+
906
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
907
+ GGML_ASSERT(ctx_ptr != nullptr);
908
+ ggml_context * ctx = ctx_ptr.get();
909
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
910
+
911
+ if (tensor == nullptr) {
912
+ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
913
+ return false;
914
+ }
915
+
916
+ if (tensor->buffer == nullptr) {
917
+ //No buffer allocated.
918
+ buft = ggml_backend_get_default_buffer_type(backend);
919
+ } else {
920
+ buft = tensor->buffer->buft;
921
+ }
922
+
923
+ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
924
+
925
+ return true;
926
+ }
927
+
773
928
  void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
774
929
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
775
930
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -781,7 +936,7 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
781
936
  GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
782
937
  buffers.insert(buffer);
783
938
  } else {
784
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
939
+ GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
785
940
  }
786
941
  }
787
942
 
@@ -803,7 +958,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
803
958
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
804
959
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
805
960
  if (buffers.find(buffer) == buffers.end()) {
806
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
961
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
807
962
  return false;
808
963
  }
809
964
  void * base = ggml_backend_buffer_get_base(buffer);
@@ -815,7 +970,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
815
970
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
816
971
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
817
972
  if (buffers.find(buffer) == buffers.end()) {
818
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
973
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
819
974
  return false;
820
975
  }
821
976
  ggml_backend_buffer_free(buffer);
@@ -827,7 +982,7 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
827
982
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
828
983
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
829
984
  if (buffers.find(buffer) == buffers.end()) {
830
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
985
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
831
986
  return false;
832
987
  }
833
988
  ggml_backend_buffer_clear(buffer, request.value);
@@ -835,8 +990,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
835
990
  }
836
991
 
837
992
  ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
993
+ // Validate tensor type before using it
994
+ if (tensor->type >= GGML_TYPE_COUNT) {
995
+ GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
996
+ return nullptr;
997
+ }
998
+
838
999
  ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
839
1000
  tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1001
+
1002
+ // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
1003
+ if (result == nullptr) {
1004
+ GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
1005
+ return nullptr;
1006
+ }
1007
+
840
1008
  for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
841
1009
  result->nb[i] = tensor->nb[i];
842
1010
  }
@@ -880,11 +1048,12 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
880
1048
  /*.mem_buffer =*/ NULL,
881
1049
  /*.no_alloc =*/ true,
882
1050
  };
883
- struct ggml_context * ctx = ggml_init(params);
1051
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1052
+ GGML_ASSERT(ctx_ptr != nullptr);
1053
+ ggml_context * ctx = ctx_ptr.get();
884
1054
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
885
1055
  if (tensor == nullptr) {
886
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
887
- ggml_free(ctx);
1056
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
888
1057
  return false;
889
1058
  }
890
1059
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
@@ -895,13 +1064,118 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
895
1064
  const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
896
1065
 
897
1066
  if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
898
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1067
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
1068
+ __func__, in_tensor->data, offset, size, p0, p1);
1069
+ return false;
899
1070
  }
900
1071
  }
901
1072
 
902
1073
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
1074
+ if (cache_dir && size > HASH_THRESHOLD) {
1075
+ uint64_t hash = fnv_hash((const uint8_t*)data, size);
1076
+ char hash_str[17];
1077
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1078
+ // save to cache_dir/hash_str
1079
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1080
+ std::ofstream ofs(cache_file, std::ios::binary);
1081
+ ofs.write((const char *)data, size);
1082
+ printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1083
+ }
903
1084
  ggml_backend_tensor_set(tensor, data, offset, size);
904
- ggml_free(ctx);
1085
+ return true;
1086
+ }
1087
+
1088
+ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1089
+ if (!cache_dir) {
1090
+ return false;
1091
+ }
1092
+ char hash_str[17];
1093
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1094
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1095
+ if (!fs::exists(cache_file)) {
1096
+ return false;
1097
+ }
1098
+ std::ifstream ifs(cache_file, std::ios::binary);
1099
+ ifs.seekg(0, std::ios::end);
1100
+ size_t size = ifs.tellg();
1101
+ ifs.seekg(0, std::ios::beg);
1102
+ data.resize(size);
1103
+ ifs.read((char *)data.data(), size);
1104
+ return true;
1105
+ }
1106
+
1107
+ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
1108
+ {
1109
+ std::vector<uint8_t> cached_file;
1110
+ if (!get_cached_file(request.hash, cached_file)) {
1111
+ response.result = 0;
1112
+ return true;
1113
+ }
1114
+ size_t size = cached_file.size();
1115
+ struct ggml_init_params params {
1116
+ /*.mem_size =*/ ggml_tensor_overhead(),
1117
+ /*.mem_buffer =*/ NULL,
1118
+ /*.no_alloc =*/ true,
1119
+ };
1120
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1121
+ GGML_ASSERT(ctx_ptr != nullptr);
1122
+ ggml_context * ctx = ctx_ptr.get();
1123
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1124
+ if (tensor == nullptr) {
1125
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1126
+ return false;
1127
+ }
1128
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1129
+ __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1130
+
1131
+ // sanitize tensor->data
1132
+ {
1133
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1134
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1135
+
1136
+ if (request.tensor.data + request.offset < p0
1137
+ || request.tensor.data + request.offset >= p1
1138
+ || size > (p1 - request.tensor.data - request.offset)) {
1139
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1140
+ __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
1141
+ return false;
1142
+ }
1143
+ }
1144
+ ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
1145
+ response.result = 1;
1146
+ return true;
1147
+ }
1148
+
1149
+ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
1150
+ struct ggml_init_params params {
1151
+ /*.mem_size =*/ ggml_tensor_overhead(),
1152
+ /*.mem_buffer =*/ NULL,
1153
+ /*.no_alloc =*/ true,
1154
+ };
1155
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1156
+ GGML_ASSERT(ctx_ptr != nullptr);
1157
+ ggml_context * ctx = ctx_ptr.get();
1158
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1159
+ if (tensor == nullptr) {
1160
+ GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
1161
+ return false;
1162
+ }
1163
+
1164
+ // Call the backend's buffer_init_tensor function
1165
+ ggml_backend_buffer_t buffer = tensor->buffer;
1166
+ if (buffer && buffer->iface.init_tensor) {
1167
+ buffer->iface.init_tensor(buffer, tensor);
1168
+ } else {
1169
+ GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
1170
+ }
1171
+
1172
+ if (tensor->extra != nullptr) {
1173
+ // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
1174
+ // Currently unimplemented.
1175
+ GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
1176
+ return false;
1177
+ }
1178
+
905
1179
  return true;
906
1180
  }
907
1181
 
@@ -911,11 +1185,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
911
1185
  /*.mem_buffer =*/ NULL,
912
1186
  /*.no_alloc =*/ true,
913
1187
  };
914
- struct ggml_context * ctx = ggml_init(params);
1188
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1189
+ GGML_ASSERT(ctx_ptr != nullptr);
1190
+ ggml_context * ctx = ctx_ptr.get();
915
1191
  ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
916
1192
  if (tensor == nullptr) {
917
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
918
- ggml_free(ctx);
1193
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
919
1194
  return false;
920
1195
  }
921
1196
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
@@ -928,13 +1203,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
928
1203
  if (request.tensor.data + request.offset < p0 ||
929
1204
  request.tensor.data + request.offset >= p1 ||
930
1205
  request.size > (p1 - request.tensor.data - request.offset)) {
931
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1206
+ GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1207
+ __func__, request.tensor.data, request.offset, request.size, p0, p1);
1208
+ return false;
932
1209
  }
933
1210
  }
934
1211
 
935
1212
  response.resize(request.size, 0);
936
1213
  ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
937
- ggml_free(ctx);
938
1214
  return true;
939
1215
  }
940
1216
 
@@ -944,17 +1220,38 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
944
1220
  /*.mem_buffer =*/ NULL,
945
1221
  /*.no_alloc =*/ true,
946
1222
  };
947
- struct ggml_context * ctx = ggml_init(params);
1223
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1224
+ GGML_ASSERT(ctx_ptr != nullptr);
1225
+ ggml_context * ctx = ctx_ptr.get();
1226
+
948
1227
  ggml_tensor * src = deserialize_tensor(ctx, &request.src);
949
1228
  ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
950
1229
  if (src == nullptr || dst == nullptr) {
951
- GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
952
- ggml_free(ctx);
1230
+ GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
1231
+ return false;
1232
+ }
1233
+
1234
+ uint64_t src_size = (uint64_t) ggml_nbytes(src);
1235
+ uint64_t dst_data = (uint64_t) dst->data;
1236
+ uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
1237
+ uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
1238
+
1239
+ if (dst_data + src_size > dst_base + dst_buf_sz) {
1240
+ GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1241
+ " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
1242
+ " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
1243
+ __func__,
1244
+ dst_data,
1245
+ dst_data + src_size,
1246
+ dst_base,
1247
+ dst_base + dst_buf_sz);
953
1248
  return false;
954
1249
  }
955
- GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
1250
+
1251
+ GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
1252
+ __func__, (void*) src->buffer, (void*) dst->buffer);
1253
+
956
1254
  response.result = ggml_backend_buffer_copy_tensor(src, dst);
957
- ggml_free(ctx);
958
1255
  return true;
959
1256
  }
960
1257
 
@@ -962,22 +1259,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
962
1259
  struct ggml_context * ctx,
963
1260
  const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
964
1261
  std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
965
- if (id == 0) {
966
- return nullptr;
967
- }
968
1262
  if (tensor_map.find(id) != tensor_map.end()) {
969
1263
  return tensor_map[id];
970
1264
  }
971
- const rpc_tensor * tensor = tensor_ptrs.at(id);
1265
+ // Safely find the tensor pointer
1266
+ auto it_ptr = tensor_ptrs.find(id);
1267
+ if (it_ptr == tensor_ptrs.end()) {
1268
+ return nullptr;
1269
+ }
1270
+ const rpc_tensor * tensor = it_ptr->second;
1271
+
972
1272
  struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
973
1273
  if (result == nullptr) {
974
1274
  return nullptr;
975
1275
  }
976
1276
  tensor_map[id] = result;
977
1277
  for (int i = 0; i < GGML_MAX_SRC; i++) {
978
- result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1278
+ // Check if the source ID is 0 before calling create_node recursively
1279
+ if (tensor->src[i] == 0) {
1280
+ result->src[i] = nullptr;
1281
+ } else {
1282
+ result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1283
+ // If the recursive call failed for a non-zero ID, propagate the error
1284
+ if (result->src[i] == nullptr) {
1285
+ GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1286
+ __func__, i, tensor->src[i], id);
1287
+ // Must return nullptr to signal failure up the call stack
1288
+ return nullptr;
1289
+ }
1290
+ }
1291
+ }
1292
+
1293
+ // Handle view_src similarly
1294
+ if (tensor->view_src == 0) {
1295
+ result->view_src = nullptr;
1296
+ } else {
1297
+ result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
1298
+ // If the recursive call failed for a non-zero ID, propagate the error
1299
+ if (result->view_src == nullptr) {
1300
+ GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1301
+ __func__, tensor->view_src, id);
1302
+ // Must return nullptr to signal failure up the call stack
1303
+ return nullptr;
1304
+ }
979
1305
  }
980
- result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
981
1306
  result->view_offs = tensor->view_offs;
982
1307
  return result;
983
1308
  }
@@ -1003,12 +1328,15 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1003
1328
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1004
1329
 
1005
1330
  size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1331
+
1006
1332
  struct ggml_init_params params = {
1007
1333
  /*.mem_size =*/ buf_size,
1008
1334
  /*.mem_buffer =*/ NULL,
1009
1335
  /*.no_alloc =*/ true,
1010
1336
  };
1011
- struct ggml_context * ctx = ggml_init(params);
1337
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1338
+ GGML_ASSERT(ctx_ptr != nullptr);
1339
+ ggml_context * ctx = ctx_ptr.get();
1012
1340
  struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1013
1341
  graph->n_nodes = n_nodes;
1014
1342
  std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
@@ -1020,10 +1348,17 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1020
1348
  int64_t id;
1021
1349
  memcpy(&id, &nodes[i], sizeof(id));
1022
1350
  graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1351
+
1352
+ // Check if create_node failed for a *non-zero* ID.
1353
+ // If id was 0, create_node returning nullptr is expected.
1354
+ // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
1355
+ if (graph->nodes[i] == nullptr && id != 0) {
1356
+ GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
1357
+ return false;
1358
+ }
1023
1359
  }
1024
1360
  ggml_status status = ggml_backend_graph_compute(backend, graph);
1025
1361
  response.result = status;
1026
- ggml_free(ctx);
1027
1362
  return true;
1028
1363
  }
1029
1364
 
@@ -1033,10 +1368,27 @@ rpc_server::~rpc_server() {
1033
1368
  }
1034
1369
  }
1035
1370
 
1036
- static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1037
- rpc_server server(backend);
1371
+ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1372
+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1373
+ rpc_server server(backend, cache_dir);
1374
+ uint8_t cmd;
1375
+ if (!recv_data(sockfd, &cmd, 1)) {
1376
+ return;
1377
+ }
1378
+ // the first command sent by the client must be HELLO
1379
+ if (cmd != RPC_CMD_HELLO) {
1380
+ fprintf(stderr, "Expected HELLO command, update client\n");
1381
+ return;
1382
+ }
1383
+ if (!recv_msg(sockfd, nullptr, 0)) {
1384
+ return;
1385
+ }
1386
+ rpc_msg_hello_rsp response;
1387
+ server.hello(response);
1388
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1389
+ return;
1390
+ }
1038
1391
  while (true) {
1039
- uint8_t cmd;
1040
1392
  if (!recv_data(sockfd, &cmd, 1)) {
1041
1393
  break;
1042
1394
  }
@@ -1046,6 +1398,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1046
1398
  break;
1047
1399
  }
1048
1400
  switch (cmd) {
1401
+ case RPC_CMD_HELLO: {
1402
+ // HELLO command is handled above
1403
+ return;
1404
+ }
1049
1405
  case RPC_CMD_ALLOC_BUFFER: {
1050
1406
  rpc_msg_alloc_buffer_req request;
1051
1407
  if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1058,6 +1414,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1058
1414
  }
1059
1415
  break;
1060
1416
  }
1417
+ case RPC_CMD_GET_ALLOC_SIZE: {
1418
+ rpc_msg_get_alloc_size_req request;
1419
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1420
+ return;
1421
+ }
1422
+ rpc_msg_get_alloc_size_rsp response;
1423
+ if (!server.get_alloc_size(request, response)) {
1424
+ return;
1425
+ }
1426
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1427
+ return;
1428
+ }
1429
+ break;
1430
+ }
1061
1431
  case RPC_CMD_GET_ALIGNMENT: {
1062
1432
  if (!recv_msg(sockfd, nullptr, 0)) {
1063
1433
  return;
@@ -1128,6 +1498,30 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1128
1498
  if (!server.set_tensor(input)) {
1129
1499
  return;
1130
1500
  }
1501
+ break;
1502
+ }
1503
+ case RPC_CMD_SET_TENSOR_HASH: {
1504
+ rpc_msg_set_tensor_hash_req request;
1505
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1506
+ return;
1507
+ }
1508
+ rpc_msg_set_tensor_hash_rsp response;
1509
+ if (!server.set_tensor_hash(request, response)) {
1510
+ return;
1511
+ }
1512
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1513
+ return;
1514
+ }
1515
+ break;
1516
+ }
1517
+ case RPC_CMD_INIT_TENSOR: {
1518
+ rpc_msg_init_tensor_req request;
1519
+ if (!recv_msg(sockfd, &request,sizeof(request))) {
1520
+ return;
1521
+ }
1522
+ if (!server.init_tensor(request)) {
1523
+ return;
1524
+ }
1131
1525
  if (!send_msg(sockfd, nullptr, 0)) {
1132
1526
  return;
1133
1527
  }
@@ -1195,7 +1589,17 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1195
1589
  }
1196
1590
  }
1197
1591
 
1198
- void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1592
+ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
1593
+ const char * cache_dir,
1594
+ size_t free_mem, size_t total_mem) {
1595
+ printf("Starting RPC server v%d.%d.%d\n",
1596
+ RPC_PROTO_MAJOR_VERSION,
1597
+ RPC_PROTO_MINOR_VERSION,
1598
+ RPC_PROTO_PATCH_VERSION);
1599
+ printf(" endpoint : %s\n", endpoint);
1600
+ printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
1601
+ printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
1602
+
1199
1603
  std::string host;
1200
1604
  int port;
1201
1605
  if (!parse_endpoint(endpoint, host, port)) {
@@ -1224,7 +1628,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
1224
1628
  }
1225
1629
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1226
1630
  fflush(stdout);
1227
- rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1631
+ rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
1228
1632
  printf("Client connection closed\n");
1229
1633
  fflush(stdout);
1230
1634
  }
@@ -1257,14 +1661,14 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t *
1257
1661
 
1258
1662
  ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1259
1663
 
1260
- UNUSED(dev);
1664
+ GGML_UNUSED(dev);
1261
1665
  }
1262
1666
 
1263
1667
  static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1264
1668
  // TODO: obtain value from the server
1265
1669
  return GGML_BACKEND_DEVICE_TYPE_GPU;
1266
1670
 
1267
- UNUSED(dev);
1671
+ GGML_UNUSED(dev);
1268
1672
  }
1269
1673
 
1270
1674
  static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
@@ -1285,7 +1689,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
1285
1689
 
1286
1690
  return ggml_backend_rpc_init(ctx->endpoint.c_str());
1287
1691
 
1288
- UNUSED(params);
1692
+ GGML_UNUSED(params);
1289
1693
  }
1290
1694
 
1291
1695
  static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -1293,12 +1697,12 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b
1293
1697
 
1294
1698
  return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1295
1699
 
1296
- UNUSED(dev);
1700
+ GGML_UNUSED(dev);
1297
1701
  }
1298
1702
 
1299
1703
  static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1300
- UNUSED(dev);
1301
- UNUSED(op);
1704
+ GGML_UNUSED(dev);
1705
+ GGML_UNUSED(op);
1302
1706
  //TODO: call the remote backend and cache the results
1303
1707
  return true;
1304
1708
  }
@@ -1335,29 +1739,32 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1335
1739
  static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1336
1740
  return "RPC";
1337
1741
 
1338
- UNUSED(reg);
1742
+ GGML_UNUSED(reg);
1339
1743
  }
1340
1744
 
1341
1745
  static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1342
1746
  return 0;
1343
1747
 
1344
- UNUSED(reg);
1748
+ GGML_UNUSED(reg);
1345
1749
  }
1346
1750
 
1347
1751
  static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1348
1752
  GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1349
1753
 
1350
- UNUSED(reg);
1351
- UNUSED(index);
1754
+ GGML_UNUSED(reg);
1755
+ GGML_UNUSED(index);
1352
1756
  }
1353
1757
 
1354
1758
  static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1355
1759
  if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1356
1760
  return (void *)ggml_backend_rpc_add_device;
1357
1761
  }
1762
+ if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
1763
+ return (void *)ggml_backend_rpc_start_server;
1764
+ }
1358
1765
  return NULL;
1359
1766
 
1360
- UNUSED(reg);
1767
+ GGML_UNUSED(reg);
1361
1768
  }
1362
1769
 
1363
1770
  static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {