whispercpp 1.3.1 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (857):
  1. checksums.yaml +4 -4
  2. data/.gitignore +7 -3
  3. data/README.md +161 -43
  4. data/Rakefile +45 -13
  5. data/ext/.gitignore +4 -8
  6. data/ext/dependencies.rb +73 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +85 -0
  9. data/ext/ruby_whisper.c +177 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +672 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1303 -0
  15. data/ext/ruby_whisper_segment.c +220 -0
  16. data/ext/ruby_whisper_transcribe.cpp +93 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  19. data/ext/sources/CMakeLists.txt +255 -0
  20. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  21. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  22. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  23. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  24. data/ext/sources/bindings/javascript/package.json +26 -0
  25. data/ext/sources/bindings/javascript/whisper.js +19 -0
  26. data/ext/sources/build-xcframework.sh +547 -0
  27. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  28. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  29. data/ext/sources/cmake/build-info.cmake +60 -0
  30. data/ext/sources/cmake/git-vars.cmake +22 -0
  31. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  32. data/ext/sources/cmake/whisper.pc.in +10 -0
  33. data/ext/sources/examples/CMakeLists.txt +124 -0
  34. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  35. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
  36. data/ext/sources/examples/addon.node/addon.cpp +557 -0
  37. data/ext/sources/examples/addon.node/index.js +57 -0
  38. data/ext/sources/examples/addon.node/package.json +16 -0
  39. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  40. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  41. data/ext/sources/examples/bench/bench.cpp +176 -0
  42. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  43. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  44. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  45. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  46. data/ext/sources/examples/cli/cli.cpp +1295 -0
  47. data/ext/sources/examples/coi-serviceworker.js +146 -0
  48. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  49. data/ext/sources/examples/command/command.cpp +800 -0
  50. data/ext/sources/examples/command/commands.txt +9 -0
  51. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  52. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  53. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  54. data/ext/sources/examples/common-ggml.cpp +238 -0
  55. data/ext/sources/examples/common-ggml.h +18 -0
  56. data/ext/sources/examples/common-sdl.cpp +227 -0
  57. data/ext/sources/examples/common-sdl.h +49 -0
  58. data/ext/sources/examples/common-whisper.cpp +175 -0
  59. data/ext/sources/examples/common-whisper.h +24 -0
  60. data/ext/sources/examples/common.cpp +675 -0
  61. data/ext/sources/examples/common.h +322 -0
  62. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  63. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  64. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  65. data/ext/sources/examples/generate-karaoke.sh +57 -0
  66. data/ext/sources/examples/grammar-parser.cpp +423 -0
  67. data/ext/sources/examples/grammar-parser.h +29 -0
  68. data/ext/sources/examples/helpers.js +191 -0
  69. data/ext/sources/examples/json.hpp +24596 -0
  70. data/ext/sources/examples/livestream.sh +112 -0
  71. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  72. data/ext/sources/examples/lsp/lsp.cpp +469 -0
  73. data/ext/sources/examples/lsp/whisper.vim +362 -0
  74. data/ext/sources/examples/miniaudio.h +93468 -0
  75. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  76. data/ext/sources/examples/python/whisper_processor.py +54 -0
  77. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  78. data/ext/sources/examples/quantize/quantize.cpp +226 -0
  79. data/ext/sources/examples/server/CMakeLists.txt +15 -0
  80. data/ext/sources/examples/server/bench.js +29 -0
  81. data/ext/sources/examples/server/httplib.h +10497 -0
  82. data/ext/sources/examples/server/server.cpp +1238 -0
  83. data/ext/sources/examples/server.py +115 -0
  84. data/ext/sources/examples/stb_vorbis.c +5584 -0
  85. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  86. data/ext/sources/examples/stream/stream.cpp +435 -0
  87. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  88. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  89. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  90. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  91. data/ext/sources/examples/sycl/build.sh +22 -0
  92. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  93. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  94. data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
  95. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  96. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  97. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  98. data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
  99. data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
  100. data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
  101. data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
  102. data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
  103. data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
  104. data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
  105. data/ext/sources/examples/talk-llama/llama-context.h +297 -0
  106. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  107. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  108. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  109. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  110. data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
  111. data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
  112. data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
  113. data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
  114. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  115. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  116. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  117. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  118. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  119. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  120. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
  124. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  125. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  126. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  127. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  128. data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
  129. data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
  130. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  131. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  132. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
  133. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  134. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
  135. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  136. data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
  137. data/ext/sources/examples/talk-llama/llama-model.h +452 -0
  138. data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
  139. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  140. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  141. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  142. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
  143. data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
  144. data/ext/sources/examples/talk-llama/llama.cpp +358 -0
  145. data/ext/sources/examples/talk-llama/llama.h +1484 -0
  146. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  147. data/ext/sources/examples/talk-llama/speak +40 -0
  148. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  149. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  150. data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
  151. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  152. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  153. data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
  154. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  155. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  156. data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
  157. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  158. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  159. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  160. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  161. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  162. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  163. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  164. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  165. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
  166. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  167. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  168. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  169. data/ext/sources/ggml/CMakeLists.txt +435 -0
  170. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  171. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  172. data/ext/sources/ggml/cmake/common.cmake +50 -0
  173. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  174. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
  176. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
  178. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  179. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  180. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  181. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  182. data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
  183. data/ext/sources/ggml/include/gguf.h +202 -0
  184. data/ext/sources/ggml/src/CMakeLists.txt +404 -0
  185. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  186. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  187. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  188. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
  189. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  191. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
  192. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  195. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  196. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  197. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
  198. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
  199. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  200. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  201. data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
  202. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
  203. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
  205. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  208. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  209. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  210. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  211. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  212. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  213. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  214. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  215. data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
  216. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  217. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
  218. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  219. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  220. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  221. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  222. data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
  223. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
  224. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
  225. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
  226. data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
  227. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  228. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  229. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  230. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  231. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
  232. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
  233. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
  234. data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
  235. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  236. data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
  237. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  238. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  239. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
  240. data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
  241. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  242. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  243. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
  244. data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
  245. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  246. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  247. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  248. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  249. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  250. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  251. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  252. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  253. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  254. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  255. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  256. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  257. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
  259. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  260. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  262. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  264. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  266. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  267. data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
  269. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  270. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  271. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  273. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  275. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  276. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  277. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  278. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  287. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  289. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  290. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  291. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  292. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  293. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
  294. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  295. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  296. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  297. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  298. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  299. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  300. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  301. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  302. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
  304. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
  305. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  307. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  308. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  309. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  310. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  312. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  313. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  314. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  315. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  316. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  317. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  318. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  319. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  320. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  321. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  322. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  323. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  324. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  325. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  326. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  327. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
  328. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  329. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  330. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
  332. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  430. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  432. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  433. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  434. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  436. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  437. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  438. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  439. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  440. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  441. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  442. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  443. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  444. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  445. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  446. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  447. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  448. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  449. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  450. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  451. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  452. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  453. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  454. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  455. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  456. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  457. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  458. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  459. data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
  460. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
  461. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  462. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  463. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  464. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  465. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  466. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  467. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  468. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  469. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
  470. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
  471. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  481. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  482. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  483. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  484. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  485. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  486. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  487. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  488. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  489. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  490. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  491. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  492. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  493. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  494. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  495. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  496. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  497. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  498. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  499. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  500. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  501. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  502. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  503. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  504. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  505. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  506. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  507. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  508. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  509. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
  510. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
  511. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
  512. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
  513. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  514. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  515. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  516. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
  517. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  526. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  527. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  528. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  529. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  530. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  531. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  532. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  533. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  534. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  535. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  536. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  537. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  538. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  539. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  540. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  541. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  542. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  543. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  544. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  545. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  546. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  547. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  548. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  549. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  550. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  551. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  552. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  553. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  554. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  555. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  556. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  557. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  558. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  559. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  560. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  561. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  562. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  563. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  564. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  565. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  566. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  567. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  568. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  569. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
  570. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  571. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
  572. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  573. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  574. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
  575. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  576. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  577. data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
  578. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
  579. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  580. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
  581. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  582. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
  583. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  584. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
  585. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  586. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
  587. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
  588. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  589. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
  590. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
  591. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
  592. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
  593. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
  594. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  595. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
  596. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  597. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  598. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  599. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  600. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
  601. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  602. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
  603. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  604. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
  605. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  606. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  607. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  608. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  609. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
  610. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
  611. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  612. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
  613. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  614. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
  615. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
  616. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
  617. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  618. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
  619. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
  620. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  623. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  740. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  743. data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
  744. data/ext/sources/ggml/src/ggml.cpp +26 -0
  745. data/ext/sources/ggml/src/gguf.cpp +1351 -0
  746. data/ext/{include → sources/include}/whisper.h +70 -2
  747. data/ext/sources/src/CMakeLists.txt +145 -0
  748. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  749. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  750. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  751. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
  752. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  753. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
  754. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  755. data/ext/sources/src/whisper-arch.h +197 -0
  756. data/ext/{src → sources/src}/whisper.cpp +1966 -386
  757. data/ext/sources/tests/CMakeLists.txt +105 -0
  758. data/ext/sources/tests/earnings21/eval.mk +58 -0
  759. data/ext/sources/tests/earnings21/eval.py +68 -0
  760. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  761. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  762. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  763. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  764. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  765. data/ext/sources/tests/en-0-ref.txt +1 -0
  766. data/ext/sources/tests/en-1-ref.txt +1 -0
  767. data/ext/sources/tests/en-2-ref.txt +1 -0
  768. data/ext/sources/tests/es-0-ref.txt +1 -0
  769. data/ext/sources/tests/librispeech/eval.mk +39 -0
  770. data/ext/sources/tests/librispeech/eval.py +47 -0
  771. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  772. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  773. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  774. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  775. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  776. data/ext/sources/tests/run-tests.sh +130 -0
  777. data/ext/sources/tests/test-c.c +3 -0
  778. data/ext/sources/tests/test-vad-full.cpp +54 -0
  779. data/ext/sources/tests/test-vad.cpp +83 -0
  780. data/ext/sources/tests/test-whisper.js +58 -0
  781. data/extsources.rb +39 -5
  782. data/lib/whisper/context.rb +15 -0
  783. data/lib/whisper/model/uri.rb +202 -126
  784. data/lib/whisper/segment.rb +58 -0
  785. data/sig/whisper.rbs +510 -0
  786. data/test/helper.rb +24 -0
  787. data/{tests → test}/test_callback.rb +45 -3
  788. data/{tests → test}/test_error.rb +2 -2
  789. data/{tests → test}/test_model.rb +47 -0
  790. data/test/test_package.rb +51 -0
  791. data/test/test_params.rb +297 -0
  792. data/test/test_segment.rb +146 -0
  793. data/test/test_vad.rb +19 -0
  794. data/test/test_vad_params.rb +103 -0
  795. data/{tests → test}/test_whisper.rb +106 -36
  796. data/whispercpp.gemspec +5 -5
  797. metadata +837 -134
  798. data/ext/cpu.mk +0 -9
  799. data/ext/examples/dr_wav.h +0 -8815
  800. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  801. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  802. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  803. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
  804. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  805. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  806. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  807. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  808. data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
  809. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  810. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  811. data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
  812. data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
  813. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  814. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  815. data/ext/metal-embed.mk +0 -17
  816. data/ext/metal.mk +0 -6
  817. data/ext/ruby_whisper.cpp +0 -1909
  818. data/ext/scripts/get-flags.mk +0 -38
  819. data/lib/whisper.rb +0 -2
  820. data/tests/helper.rb +0 -7
  821. data/tests/test_package.rb +0 -31
  822. data/tests/test_params.rb +0 -160
  823. data/tests/test_segment.rb +0 -83
  824. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  825. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  826. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  827. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  828. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  829. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  830. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  831. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  832. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  833. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  834. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  835. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  836. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  837. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  838. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  839. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  840. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  841. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  842. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  843. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  844. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  845. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  846. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
  847. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
  848. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  849. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  850. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  851. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  852. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  853. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  854. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
  855. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  856. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  857. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
@@ -1,6 +1,7 @@
1
1
  #include "ggml-rpc.h"
2
2
  #include "ggml-impl.h"
3
3
  #include "ggml-backend-impl.h"
4
+ #include "ggml-cpp.h"
4
5
 
5
6
  #include <cinttypes>
6
7
  #include <string>
@@ -26,15 +27,10 @@
26
27
  # include <unistd.h>
27
28
  #endif
28
29
  #include <cstring>
30
+ #include <fstream>
31
+ #include <filesystem>
29
32
 
30
- #define UNUSED GGML_UNUSED
31
-
32
- #define GGML_DEBUG 0
33
- #if (GGML_DEBUG >= 1)
34
- #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
35
- #else
36
- #define GGML_PRINT_DEBUG(...)
37
- #endif
33
+ namespace fs = std::filesystem;
38
34
 
39
35
  #ifdef _WIN32
40
36
  typedef SOCKET sockfd_t;
@@ -57,6 +53,9 @@ struct socket_t {
57
53
  }
58
54
  };
59
55
 
56
+ // macro for nicer error messages on server crash
57
+ #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
58
+
60
59
  // all RPC structures must be packed
61
60
  #pragma pack(push, 1)
62
61
  // ggml_tensor is serialized into rpc_tensor
@@ -89,13 +88,38 @@ enum rpc_cmd {
89
88
  RPC_CMD_FREE_BUFFER,
90
89
  RPC_CMD_BUFFER_CLEAR,
91
90
  RPC_CMD_SET_TENSOR,
91
+ RPC_CMD_SET_TENSOR_HASH,
92
92
  RPC_CMD_GET_TENSOR,
93
93
  RPC_CMD_COPY_TENSOR,
94
94
  RPC_CMD_GRAPH_COMPUTE,
95
95
  RPC_CMD_GET_DEVICE_MEMORY,
96
+ RPC_CMD_INIT_TENSOR,
97
+ RPC_CMD_GET_ALLOC_SIZE,
98
+ RPC_CMD_HELLO,
96
99
  RPC_CMD_COUNT,
97
100
  };
98
101
 
102
+ // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
103
+ const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
104
+
105
+ struct rpc_msg_hello_rsp {
106
+ uint8_t major;
107
+ uint8_t minor;
108
+ uint8_t patch;
109
+ };
110
+
111
+ struct rpc_msg_get_alloc_size_req {
112
+ rpc_tensor tensor;
113
+ };
114
+
115
+ struct rpc_msg_get_alloc_size_rsp {
116
+ uint64_t alloc_size;
117
+ };
118
+
119
+ struct rpc_msg_init_tensor_req {
120
+ rpc_tensor tensor;
121
+ };
122
+
99
123
  struct rpc_msg_alloc_buffer_req {
100
124
  uint64_t size;
101
125
  };
@@ -130,6 +154,16 @@ struct rpc_msg_buffer_clear_req {
130
154
  uint8_t value;
131
155
  };
132
156
 
157
+ struct rpc_msg_set_tensor_hash_req {
158
+ rpc_tensor tensor;
159
+ uint64_t offset;
160
+ uint64_t hash;
161
+ };
162
+
163
+ struct rpc_msg_set_tensor_hash_rsp {
164
+ uint8_t result;
165
+ };
166
+
133
167
  struct rpc_msg_get_tensor_req {
134
168
  rpc_tensor tensor;
135
169
  uint64_t offset;
@@ -176,12 +210,24 @@ struct ggml_backend_rpc_context {
176
210
 
177
211
  struct ggml_backend_rpc_buffer_context {
178
212
  std::shared_ptr<socket_t> sock;
179
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
213
+ void * base_ptr;
180
214
  uint64_t remote_ptr;
181
215
  };
182
216
 
183
217
  // RPC helper functions
184
218
 
219
+ // Computes FNV-1a hash of the data
220
+ static uint64_t fnv_hash(const uint8_t * data, size_t len) {
221
+ const uint64_t fnv_prime = 0x100000001b3ULL;
222
+ uint64_t hash = 0xcbf29ce484222325ULL;
223
+
224
+ for (size_t i = 0; i < len; ++i) {
225
+ hash ^= data[i];
226
+ hash *= fnv_prime;
227
+ }
228
+ return hash;
229
+ }
230
+
185
231
  static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
186
232
  #ifdef _WIN32
187
233
  if (fd == INVALID_SOCKET) {
@@ -341,8 +387,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
341
387
  }
342
388
 
343
389
  // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
344
- // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
345
- static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
390
+ // No response
391
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
346
392
  uint8_t cmd_byte = cmd;
347
393
  if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
348
394
  return false;
@@ -353,6 +399,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
353
399
  if (!send_data(sock->fd, input, input_size)) {
354
400
  return false;
355
401
  }
402
+ return true;
403
+ }
404
+
405
+ // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
406
+ // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
407
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
408
+ if (!send_rpc_cmd(sock, cmd, input, input_size)) {
409
+ return false;
410
+ }
356
411
  // TODO: currently the output_size is always known, do we need support for commands with variable output size?
357
412
  // even if we do, we can skip sending output_size from the server for commands with known output size
358
413
  uint64_t out_size;
@@ -370,6 +425,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
370
425
 
371
426
  // RPC client-side implementation
372
427
 
428
+ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
429
+ rpc_msg_hello_rsp response;
430
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
431
+ RPC_STATUS_ASSERT(status);
432
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
433
+ fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
434
+ return false;
435
+ }
436
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
437
+ fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
438
+ }
439
+ return true;
440
+ }
441
+
373
442
  static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
374
443
  static std::mutex mutex;
375
444
  std::lock_guard<std::mutex> lock(mutex);
@@ -397,12 +466,15 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
397
466
  initialized = true;
398
467
  }
399
468
  #else
400
- UNUSED(initialized);
469
+ GGML_UNUSED(initialized);
401
470
  #endif
402
471
  auto sock = socket_connect(host.c_str(), port);
403
472
  if (sock == nullptr) {
404
473
  return nullptr;
405
474
  }
475
+ if (!check_server_version(sock)) {
476
+ return nullptr;
477
+ }
406
478
  GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
407
479
  sockets[endpoint] = sock;
408
480
  return sock;
@@ -412,22 +484,21 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
412
484
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
413
485
  rpc_msg_free_buffer_req request = {ctx->remote_ptr};
414
486
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
415
- GGML_ASSERT(status);
487
+ RPC_STATUS_ASSERT(status);
416
488
  delete ctx;
417
489
  }
418
490
 
419
491
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
420
492
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
421
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
422
- return ctx->base_cache[buffer];
493
+ if (ctx->base_ptr != nullptr) {
494
+ return ctx->base_ptr;
423
495
  }
424
496
  rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
425
497
  rpc_msg_buffer_get_base_rsp response;
426
498
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
427
- GGML_ASSERT(status);
428
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
429
- ctx->base_cache[buffer] = base_ptr;
430
- return base_ptr;
499
+ RPC_STATUS_ASSERT(status);
500
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
501
+ return ctx->base_ptr;
431
502
  }
432
503
 
433
504
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -456,29 +527,56 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
456
527
  result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
457
528
  result.view_offs = tensor->view_offs;
458
529
  result.data = reinterpret_cast<uint64_t>(tensor->data);
530
+
531
+ // Avoid sending uninitialized data over the wire
532
+ memset(result.name, 0, sizeof(result.name));
533
+ memset(result.padding, 0, sizeof(result.padding));
534
+
459
535
  snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
460
536
  return result;
461
537
  }
462
538
 
463
- static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
464
- UNUSED(buffer);
465
- if (ggml_is_quantized(tensor->type)) {
466
- // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
467
- GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
539
+ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
540
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
541
+
542
+ // CUDA backend on the server pads everything to 512 due to CUDA limitations.
543
+ // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
544
+ // In particular, only quantized tensors need padding
545
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
546
+ rpc_msg_init_tensor_req request;
547
+
548
+ request.tensor = serialize_tensor(tensor);
549
+
550
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
551
+ RPC_STATUS_ASSERT(status);
468
552
  }
553
+ return GGML_STATUS_SUCCESS;
469
554
  }
470
555
 
471
556
  static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
472
557
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
473
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
558
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
559
+ if (size > HASH_THRESHOLD) {
560
+ rpc_msg_set_tensor_hash_req request;
561
+ request.tensor = rpc_tensor;
562
+ request.offset = offset;
563
+ request.hash = fnv_hash((const uint8_t*)data, size);
564
+ rpc_msg_set_tensor_hash_rsp response;
565
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
566
+ RPC_STATUS_ASSERT(status);
567
+ if (response.result) {
568
+ // the server has the same data, no need to send it
569
+ return;
570
+ }
571
+ }
572
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
474
573
  size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
475
574
  std::vector<uint8_t> input(input_size, 0);
476
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
477
575
  memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
478
576
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
479
577
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
480
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
481
- GGML_ASSERT(status);
578
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
579
+ RPC_STATUS_ASSERT(status);
482
580
  }
483
581
 
484
582
  static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -488,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
488
586
  request.offset = offset;
489
587
  request.size = size;
490
588
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
491
- GGML_ASSERT(status);
589
+ RPC_STATUS_ASSERT(status);
492
590
  }
493
591
 
494
592
  static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -506,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
506
604
  request.dst = serialize_tensor(dst);
507
605
  rpc_msg_copy_tensor_rsp response;
508
606
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
509
- GGML_ASSERT(status);
607
+ RPC_STATUS_ASSERT(status);
510
608
  return response.result;
511
609
  }
512
610
 
@@ -514,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
514
612
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
515
613
  rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
516
614
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
517
- GGML_ASSERT(status);
615
+ RPC_STATUS_ASSERT(status);
518
616
  }
519
617
 
520
618
  static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -540,11 +638,11 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
540
638
  rpc_msg_alloc_buffer_rsp response;
541
639
  auto sock = get_socket(buft_ctx->endpoint);
542
640
  bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
543
- GGML_ASSERT(status);
641
+ RPC_STATUS_ASSERT(status);
544
642
  if (response.remote_ptr != 0) {
545
643
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
546
644
  ggml_backend_rpc_buffer_interface,
547
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
645
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
548
646
  response.remote_size);
549
647
  return buffer;
550
648
  } else {
@@ -555,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
555
653
  static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
556
654
  rpc_msg_get_alignment_rsp response;
557
655
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
558
- GGML_ASSERT(status);
656
+ RPC_STATUS_ASSERT(status);
559
657
  return response.alignment;
560
658
  }
561
659
 
@@ -567,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
567
665
  static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
568
666
  rpc_msg_get_max_size_rsp response;
569
667
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
570
- GGML_ASSERT(status);
668
+ RPC_STATUS_ASSERT(status);
571
669
  return response.max_size;
572
670
  }
573
671
 
@@ -577,8 +675,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
577
675
  }
578
676
 
579
677
  static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
580
- UNUSED(buft);
581
- return ggml_nbytes(tensor);
678
+ // See comments in init_tensor.
679
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
680
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
681
+ auto sock = get_socket(buft_ctx->endpoint);
682
+
683
+ rpc_msg_get_alloc_size_req request;
684
+
685
+ request.tensor = serialize_tensor(tensor);
686
+
687
+ rpc_msg_get_alloc_size_rsp response;
688
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
689
+ RPC_STATUS_ASSERT(status);
690
+
691
+ return response.alloc_size;
692
+ } else {
693
+ return ggml_nbytes(tensor);
694
+ }
582
695
  }
583
696
 
584
697
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -603,7 +716,7 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
603
716
  }
604
717
 
605
718
  static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
606
- UNUSED(backend);
719
+ GGML_UNUSED(backend);
607
720
  // this is no-op because we don't have any async operations
608
721
  }
609
722
 
@@ -651,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
651
764
  rpc_msg_graph_compute_rsp response;
652
765
  auto sock = get_socket(rpc_ctx->endpoint);
653
766
  bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
654
- GGML_ASSERT(status);
767
+ RPC_STATUS_ASSERT(status);
655
768
  return (enum ggml_status)response.result;
656
769
  }
657
770
 
@@ -725,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
725
838
  static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
726
839
  rpc_msg_get_device_memory_rsp response;
727
840
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
728
- GGML_ASSERT(status);
841
+ RPC_STATUS_ASSERT(status);
729
842
  *free = response.free_mem;
730
843
  *total = response.total_mem;
731
844
  }
@@ -744,9 +857,12 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
744
857
 
745
858
  class rpc_server {
746
859
  public:
747
- rpc_server(ggml_backend_t backend) : backend(backend) {}
860
+ rpc_server(ggml_backend_t backend, const char * cache_dir)
861
+ : backend(backend), cache_dir(cache_dir) {
862
+ }
748
863
  ~rpc_server();
749
864
 
865
+ void hello(rpc_msg_hello_rsp & response);
750
866
  void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
751
867
  void get_alignment(rpc_msg_get_alignment_rsp & response);
752
868
  void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -754,11 +870,15 @@ public:
754
870
  bool free_buffer(const rpc_msg_free_buffer_req & request);
755
871
  bool buffer_clear(const rpc_msg_buffer_clear_req & request);
756
872
  bool set_tensor(const std::vector<uint8_t> & input);
873
+ bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
757
874
  bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758
875
  bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759
876
  bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
877
+ bool init_tensor(const rpc_msg_init_tensor_req & request);
878
+ bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
760
879
 
761
880
  private:
881
+ bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
762
882
  ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
763
883
  ggml_tensor * create_node(uint64_t id,
764
884
  struct ggml_context * ctx,
@@ -767,9 +887,47 @@ private:
767
887
 
768
888
 
769
889
  ggml_backend_t backend;
890
+ const char * cache_dir;
770
891
  std::unordered_set<ggml_backend_buffer_t> buffers;
771
892
  };
772
893
 
894
+ void rpc_server::hello(rpc_msg_hello_rsp & response) {
895
+ response.major = RPC_PROTO_MAJOR_VERSION;
896
+ response.minor = RPC_PROTO_MINOR_VERSION;
897
+ response.patch = RPC_PROTO_PATCH_VERSION;
898
+ GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
899
+ }
900
+
901
+ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
902
+ ggml_backend_buffer_type_t buft;
903
+ struct ggml_init_params params {
904
+ /*.mem_size =*/ ggml_tensor_overhead(),
905
+ /*.mem_buffer =*/ NULL,
906
+ /*.no_alloc =*/ true,
907
+ };
908
+
909
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
910
+ GGML_ASSERT(ctx_ptr != nullptr);
911
+ ggml_context * ctx = ctx_ptr.get();
912
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
913
+
914
+ if (tensor == nullptr) {
915
+ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
916
+ return false;
917
+ }
918
+
919
+ if (tensor->buffer == nullptr) {
920
+ //No buffer allocated.
921
+ buft = ggml_backend_get_default_buffer_type(backend);
922
+ } else {
923
+ buft = tensor->buffer->buft;
924
+ }
925
+
926
+ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
927
+
928
+ return true;
929
+ }
930
+
773
931
  void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
774
932
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
775
933
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -781,7 +939,7 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
781
939
  GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
782
940
  buffers.insert(buffer);
783
941
  } else {
784
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
942
+ GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
785
943
  }
786
944
  }
787
945
 
@@ -803,7 +961,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
803
961
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
804
962
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
805
963
  if (buffers.find(buffer) == buffers.end()) {
806
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
964
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
807
965
  return false;
808
966
  }
809
967
  void * base = ggml_backend_buffer_get_base(buffer);
@@ -815,7 +973,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
815
973
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
816
974
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
817
975
  if (buffers.find(buffer) == buffers.end()) {
818
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
976
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
819
977
  return false;
820
978
  }
821
979
  ggml_backend_buffer_free(buffer);
@@ -827,7 +985,7 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
827
985
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
828
986
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
829
987
  if (buffers.find(buffer) == buffers.end()) {
830
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
988
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
831
989
  return false;
832
990
  }
833
991
  ggml_backend_buffer_clear(buffer, request.value);
@@ -835,8 +993,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
835
993
  }
836
994
 
837
995
  ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
996
+ // Validate tensor type before using it
997
+ if (tensor->type >= GGML_TYPE_COUNT) {
998
+ GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
999
+ return nullptr;
1000
+ }
1001
+
838
1002
  ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
839
1003
  tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1004
+
1005
+ // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
1006
+ if (result == nullptr) {
1007
+ GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
1008
+ return nullptr;
1009
+ }
1010
+
840
1011
  for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
841
1012
  result->nb[i] = tensor->nb[i];
842
1013
  }
@@ -880,11 +1051,12 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
880
1051
  /*.mem_buffer =*/ NULL,
881
1052
  /*.no_alloc =*/ true,
882
1053
  };
883
- struct ggml_context * ctx = ggml_init(params);
1054
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1055
+ GGML_ASSERT(ctx_ptr != nullptr);
1056
+ ggml_context * ctx = ctx_ptr.get();
884
1057
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
885
1058
  if (tensor == nullptr) {
886
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
887
- ggml_free(ctx);
1059
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
888
1060
  return false;
889
1061
  }
890
1062
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
@@ -895,13 +1067,118 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
895
1067
  const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
896
1068
 
897
1069
  if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
898
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1070
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
1071
+ __func__, in_tensor->data, offset, size, p0, p1);
1072
+ return false;
899
1073
  }
900
1074
  }
901
1075
 
902
1076
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
1077
+ if (cache_dir && size > HASH_THRESHOLD) {
1078
+ uint64_t hash = fnv_hash((const uint8_t*)data, size);
1079
+ char hash_str[17];
1080
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1081
+ // save to cache_dir/hash_str
1082
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1083
+ std::ofstream ofs(cache_file, std::ios::binary);
1084
+ ofs.write((const char *)data, size);
1085
+ printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1086
+ }
903
1087
  ggml_backend_tensor_set(tensor, data, offset, size);
904
- ggml_free(ctx);
1088
+ return true;
1089
+ }
1090
+
1091
+ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1092
+ if (!cache_dir) {
1093
+ return false;
1094
+ }
1095
+ char hash_str[17];
1096
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1097
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1098
+ if (!fs::exists(cache_file)) {
1099
+ return false;
1100
+ }
1101
+ std::ifstream ifs(cache_file, std::ios::binary);
1102
+ ifs.seekg(0, std::ios::end);
1103
+ size_t size = ifs.tellg();
1104
+ ifs.seekg(0, std::ios::beg);
1105
+ data.resize(size);
1106
+ ifs.read((char *)data.data(), size);
1107
+ return true;
1108
+ }
1109
+
1110
+ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
1111
+ {
1112
+ std::vector<uint8_t> cached_file;
1113
+ if (!get_cached_file(request.hash, cached_file)) {
1114
+ response.result = 0;
1115
+ return true;
1116
+ }
1117
+ size_t size = cached_file.size();
1118
+ struct ggml_init_params params {
1119
+ /*.mem_size =*/ ggml_tensor_overhead(),
1120
+ /*.mem_buffer =*/ NULL,
1121
+ /*.no_alloc =*/ true,
1122
+ };
1123
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1124
+ GGML_ASSERT(ctx_ptr != nullptr);
1125
+ ggml_context * ctx = ctx_ptr.get();
1126
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1127
+ if (tensor == nullptr) {
1128
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1129
+ return false;
1130
+ }
1131
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1132
+ __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1133
+
1134
+ // sanitize tensor->data
1135
+ {
1136
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1137
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1138
+
1139
+ if (request.tensor.data + request.offset < p0
1140
+ || request.tensor.data + request.offset >= p1
1141
+ || size > (p1 - request.tensor.data - request.offset)) {
1142
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1143
+ __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
1144
+ return false;
1145
+ }
1146
+ }
1147
+ ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
1148
+ response.result = 1;
1149
+ return true;
1150
+ }
1151
+
1152
+ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
1153
+ struct ggml_init_params params {
1154
+ /*.mem_size =*/ ggml_tensor_overhead(),
1155
+ /*.mem_buffer =*/ NULL,
1156
+ /*.no_alloc =*/ true,
1157
+ };
1158
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1159
+ GGML_ASSERT(ctx_ptr != nullptr);
1160
+ ggml_context * ctx = ctx_ptr.get();
1161
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1162
+ if (tensor == nullptr) {
1163
+ GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
1164
+ return false;
1165
+ }
1166
+
1167
+ // Call the backend's buffer_init_tensor function
1168
+ ggml_backend_buffer_t buffer = tensor->buffer;
1169
+ if (buffer && buffer->iface.init_tensor) {
1170
+ buffer->iface.init_tensor(buffer, tensor);
1171
+ } else {
1172
+ GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
1173
+ }
1174
+
1175
+ if (tensor->extra != nullptr) {
1176
+ // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
1177
+ // Currently unimplemented.
1178
+ GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
1179
+ return false;
1180
+ }
1181
+
905
1182
  return true;
906
1183
  }
907
1184
 
@@ -911,11 +1188,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
911
1188
  /*.mem_buffer =*/ NULL,
912
1189
  /*.no_alloc =*/ true,
913
1190
  };
914
- struct ggml_context * ctx = ggml_init(params);
1191
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1192
+ GGML_ASSERT(ctx_ptr != nullptr);
1193
+ ggml_context * ctx = ctx_ptr.get();
915
1194
  ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
916
1195
  if (tensor == nullptr) {
917
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
918
- ggml_free(ctx);
1196
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
919
1197
  return false;
920
1198
  }
921
1199
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
@@ -928,13 +1206,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
928
1206
  if (request.tensor.data + request.offset < p0 ||
929
1207
  request.tensor.data + request.offset >= p1 ||
930
1208
  request.size > (p1 - request.tensor.data - request.offset)) {
931
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
1209
+ GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1210
+ __func__, request.tensor.data, request.offset, request.size, p0, p1);
1211
+ return false;
932
1212
  }
933
1213
  }
934
1214
 
935
1215
  response.resize(request.size, 0);
936
1216
  ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
937
- ggml_free(ctx);
938
1217
  return true;
939
1218
  }
940
1219
 
@@ -944,17 +1223,38 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
944
1223
  /*.mem_buffer =*/ NULL,
945
1224
  /*.no_alloc =*/ true,
946
1225
  };
947
- struct ggml_context * ctx = ggml_init(params);
1226
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1227
+ GGML_ASSERT(ctx_ptr != nullptr);
1228
+ ggml_context * ctx = ctx_ptr.get();
1229
+
948
1230
  ggml_tensor * src = deserialize_tensor(ctx, &request.src);
949
1231
  ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
950
1232
  if (src == nullptr || dst == nullptr) {
951
- GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
952
- ggml_free(ctx);
1233
+ GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
1234
+ return false;
1235
+ }
1236
+
1237
+ uint64_t src_size = (uint64_t) ggml_nbytes(src);
1238
+ uint64_t dst_data = (uint64_t) dst->data;
1239
+ uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
1240
+ uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
1241
+
1242
+ if (dst_data + src_size > dst_base + dst_buf_sz) {
1243
+ GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1244
+ " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
1245
+ " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
1246
+ __func__,
1247
+ dst_data,
1248
+ dst_data + src_size,
1249
+ dst_base,
1250
+ dst_base + dst_buf_sz);
953
1251
  return false;
954
1252
  }
955
- GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
1253
+
1254
+ GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
1255
+ __func__, (void*) src->buffer, (void*) dst->buffer);
1256
+
956
1257
  response.result = ggml_backend_buffer_copy_tensor(src, dst);
957
- ggml_free(ctx);
958
1258
  return true;
959
1259
  }
960
1260
 
@@ -962,22 +1262,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
962
1262
  struct ggml_context * ctx,
963
1263
  const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
964
1264
  std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
965
- if (id == 0) {
966
- return nullptr;
967
- }
968
1265
  if (tensor_map.find(id) != tensor_map.end()) {
969
1266
  return tensor_map[id];
970
1267
  }
971
- const rpc_tensor * tensor = tensor_ptrs.at(id);
1268
+ // Safely find the tensor pointer
1269
+ auto it_ptr = tensor_ptrs.find(id);
1270
+ if (it_ptr == tensor_ptrs.end()) {
1271
+ return nullptr;
1272
+ }
1273
+ const rpc_tensor * tensor = it_ptr->second;
1274
+
972
1275
  struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
973
1276
  if (result == nullptr) {
974
1277
  return nullptr;
975
1278
  }
976
1279
  tensor_map[id] = result;
977
1280
  for (int i = 0; i < GGML_MAX_SRC; i++) {
978
- result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1281
+ // Check if the source ID is 0 before calling create_node recursively
1282
+ if (tensor->src[i] == 0) {
1283
+ result->src[i] = nullptr;
1284
+ } else {
1285
+ result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1286
+ // If the recursive call failed for a non-zero ID, propagate the error
1287
+ if (result->src[i] == nullptr) {
1288
+ GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1289
+ __func__, i, tensor->src[i], id);
1290
+ // Must return nullptr to signal failure up the call stack
1291
+ return nullptr;
1292
+ }
1293
+ }
1294
+ }
1295
+
1296
+ // Handle view_src similarly
1297
+ if (tensor->view_src == 0) {
1298
+ result->view_src = nullptr;
1299
+ } else {
1300
+ result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
1301
+ // If the recursive call failed for a non-zero ID, propagate the error
1302
+ if (result->view_src == nullptr) {
1303
+ GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1304
+ __func__, tensor->view_src, id);
1305
+ // Must return nullptr to signal failure up the call stack
1306
+ return nullptr;
1307
+ }
979
1308
  }
980
- result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
981
1309
  result->view_offs = tensor->view_offs;
982
1310
  return result;
983
1311
  }
@@ -1003,12 +1331,15 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1003
1331
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1004
1332
 
1005
1333
  size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1334
+
1006
1335
  struct ggml_init_params params = {
1007
1336
  /*.mem_size =*/ buf_size,
1008
1337
  /*.mem_buffer =*/ NULL,
1009
1338
  /*.no_alloc =*/ true,
1010
1339
  };
1011
- struct ggml_context * ctx = ggml_init(params);
1340
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1341
+ GGML_ASSERT(ctx_ptr != nullptr);
1342
+ ggml_context * ctx = ctx_ptr.get();
1012
1343
  struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1013
1344
  graph->n_nodes = n_nodes;
1014
1345
  std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
@@ -1020,10 +1351,17 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1020
1351
  int64_t id;
1021
1352
  memcpy(&id, &nodes[i], sizeof(id));
1022
1353
  graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1354
+
1355
+ // Check if create_node failed for a *non-zero* ID.
1356
+ // If id was 0, create_node returning nullptr is expected.
1357
+ // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
1358
+ if (graph->nodes[i] == nullptr && id != 0) {
1359
+ GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
1360
+ return false;
1361
+ }
1023
1362
  }
1024
1363
  ggml_status status = ggml_backend_graph_compute(backend, graph);
1025
1364
  response.result = status;
1026
- ggml_free(ctx);
1027
1365
  return true;
1028
1366
  }
1029
1367
 
@@ -1033,10 +1371,27 @@ rpc_server::~rpc_server() {
1033
1371
  }
1034
1372
  }
1035
1373
 
1036
- static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1037
- rpc_server server(backend);
1374
+ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1375
+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1376
+ rpc_server server(backend, cache_dir);
1377
+ uint8_t cmd;
1378
+ if (!recv_data(sockfd, &cmd, 1)) {
1379
+ return;
1380
+ }
1381
+ // the first command sent by the client must be HELLO
1382
+ if (cmd != RPC_CMD_HELLO) {
1383
+ fprintf(stderr, "Expected HELLO command, update client\n");
1384
+ return;
1385
+ }
1386
+ if (!recv_msg(sockfd, nullptr, 0)) {
1387
+ return;
1388
+ }
1389
+ rpc_msg_hello_rsp response;
1390
+ server.hello(response);
1391
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1392
+ return;
1393
+ }
1038
1394
  while (true) {
1039
- uint8_t cmd;
1040
1395
  if (!recv_data(sockfd, &cmd, 1)) {
1041
1396
  break;
1042
1397
  }
@@ -1046,6 +1401,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1046
1401
  break;
1047
1402
  }
1048
1403
  switch (cmd) {
1404
+ case RPC_CMD_HELLO: {
1405
+ // HELLO command is handled above
1406
+ return;
1407
+ }
1049
1408
  case RPC_CMD_ALLOC_BUFFER: {
1050
1409
  rpc_msg_alloc_buffer_req request;
1051
1410
  if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1058,6 +1417,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1058
1417
  }
1059
1418
  break;
1060
1419
  }
1420
+ case RPC_CMD_GET_ALLOC_SIZE: {
1421
+ rpc_msg_get_alloc_size_req request;
1422
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1423
+ return;
1424
+ }
1425
+ rpc_msg_get_alloc_size_rsp response;
1426
+ if (!server.get_alloc_size(request, response)) {
1427
+ return;
1428
+ }
1429
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1430
+ return;
1431
+ }
1432
+ break;
1433
+ }
1061
1434
  case RPC_CMD_GET_ALIGNMENT: {
1062
1435
  if (!recv_msg(sockfd, nullptr, 0)) {
1063
1436
  return;
@@ -1128,6 +1501,30 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1128
1501
  if (!server.set_tensor(input)) {
1129
1502
  return;
1130
1503
  }
1504
+ break;
1505
+ }
1506
+ case RPC_CMD_SET_TENSOR_HASH: {
1507
+ rpc_msg_set_tensor_hash_req request;
1508
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1509
+ return;
1510
+ }
1511
+ rpc_msg_set_tensor_hash_rsp response;
1512
+ if (!server.set_tensor_hash(request, response)) {
1513
+ return;
1514
+ }
1515
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1516
+ return;
1517
+ }
1518
+ break;
1519
+ }
1520
+ case RPC_CMD_INIT_TENSOR: {
1521
+ rpc_msg_init_tensor_req request;
1522
+ if (!recv_msg(sockfd, &request,sizeof(request))) {
1523
+ return;
1524
+ }
1525
+ if (!server.init_tensor(request)) {
1526
+ return;
1527
+ }
1131
1528
  if (!send_msg(sockfd, nullptr, 0)) {
1132
1529
  return;
1133
1530
  }
@@ -1195,7 +1592,17 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1195
1592
  }
1196
1593
  }
1197
1594
 
1198
- void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1595
+ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
1596
+ const char * cache_dir,
1597
+ size_t free_mem, size_t total_mem) {
1598
+ printf("Starting RPC server v%d.%d.%d\n",
1599
+ RPC_PROTO_MAJOR_VERSION,
1600
+ RPC_PROTO_MINOR_VERSION,
1601
+ RPC_PROTO_PATCH_VERSION);
1602
+ printf(" endpoint : %s\n", endpoint);
1603
+ printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
1604
+ printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
1605
+
1199
1606
  std::string host;
1200
1607
  int port;
1201
1608
  if (!parse_endpoint(endpoint, host, port)) {
@@ -1224,7 +1631,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
1224
1631
  }
1225
1632
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1226
1633
  fflush(stdout);
1227
- rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1634
+ rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
1228
1635
  printf("Client connection closed\n");
1229
1636
  fflush(stdout);
1230
1637
  }
@@ -1257,14 +1664,14 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t *
1257
1664
 
1258
1665
  ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1259
1666
 
1260
- UNUSED(dev);
1667
+ GGML_UNUSED(dev);
1261
1668
  }
1262
1669
 
1263
1670
  static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1264
1671
  // TODO: obtain value from the server
1265
1672
  return GGML_BACKEND_DEVICE_TYPE_GPU;
1266
1673
 
1267
- UNUSED(dev);
1674
+ GGML_UNUSED(dev);
1268
1675
  }
1269
1676
 
1270
1677
  static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
@@ -1285,7 +1692,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
1285
1692
 
1286
1693
  return ggml_backend_rpc_init(ctx->endpoint.c_str());
1287
1694
 
1288
- UNUSED(params);
1695
+ GGML_UNUSED(params);
1289
1696
  }
1290
1697
 
1291
1698
  static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -1293,12 +1700,12 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b
1293
1700
 
1294
1701
  return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1295
1702
 
1296
- UNUSED(dev);
1703
+ GGML_UNUSED(dev);
1297
1704
  }
1298
1705
 
1299
1706
  static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1300
- UNUSED(dev);
1301
- UNUSED(op);
1707
+ GGML_UNUSED(dev);
1708
+ GGML_UNUSED(op);
1302
1709
  //TODO: call the remote backend and cache the results
1303
1710
  return true;
1304
1711
  }
@@ -1335,29 +1742,32 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1335
1742
  static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1336
1743
  return "RPC";
1337
1744
 
1338
- UNUSED(reg);
1745
+ GGML_UNUSED(reg);
1339
1746
  }
1340
1747
 
1341
1748
  static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1342
1749
  return 0;
1343
1750
 
1344
- UNUSED(reg);
1751
+ GGML_UNUSED(reg);
1345
1752
  }
1346
1753
 
1347
1754
  static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1348
1755
  GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1349
1756
 
1350
- UNUSED(reg);
1351
- UNUSED(index);
1757
+ GGML_UNUSED(reg);
1758
+ GGML_UNUSED(index);
1352
1759
  }
1353
1760
 
1354
1761
  static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1355
1762
  if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1356
1763
  return (void *)ggml_backend_rpc_add_device;
1357
1764
  }
1765
+ if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
1766
+ return (void *)ggml_backend_rpc_start_server;
1767
+ }
1358
1768
  return NULL;
1359
1769
 
1360
- UNUSED(reg);
1770
+ GGML_UNUSED(reg);
1361
1771
  }
1362
1772
 
1363
1773
  static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {