whispercpp 1.3.0 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -0
- data/LICENSE +1 -1
- data/README.md +216 -424
- data/Rakefile +79 -11
- data/ext/.gitignore +11 -0
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +18 -26
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +27 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/sources/ggml/include/ggml-alloc.h +76 -0
- data/ext/sources/ggml/include/ggml-backend.h +354 -0
- data/ext/sources/ggml/include/ggml-blas.h +25 -0
- data/ext/sources/ggml/include/ggml-cann.h +123 -0
- data/ext/sources/ggml/include/ggml-cpp.h +39 -0
- data/ext/sources/ggml/include/ggml-cpu.h +143 -0
- data/ext/sources/ggml/include/ggml-cuda.h +47 -0
- data/ext/sources/ggml/include/ggml-kompute.h +50 -0
- data/ext/sources/ggml/include/ggml-metal.h +66 -0
- data/ext/sources/ggml/include/ggml-opencl.h +26 -0
- data/ext/sources/ggml/include/ggml-opt.h +237 -0
- data/ext/sources/ggml/include/ggml-rpc.h +33 -0
- data/ext/sources/ggml/include/ggml-sycl.h +49 -0
- data/ext/sources/ggml/include/ggml-vulkan.h +29 -0
- data/ext/{ggml.h → sources/ggml/include/ggml.h} +621 -821
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/sources/ggml/src/ggml-alloc.c +1042 -0
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/sources/ggml/src/ggml-amx/common.h +94 -0
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/sources/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +255 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +586 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +2011 -0
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +181 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +3193 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +420 -0
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +2606 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +234 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/sources/ggml/src/ggml-common.h +1857 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +221 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/sources/ggml/src/ggml-cpu/cpu-feats-x86.cpp +327 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +508 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +13747 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +671 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +15 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +243 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +140 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/sources/ggml/src/ggml-impl.h +601 -0
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +5998 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +7089 -0
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1037 -0
- data/ext/sources/ggml/src/ggml-quants.c +5232 -0
- data/ext/sources/ggml/src/ggml-quants.h +100 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +1813 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +195 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +101 -0
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +623 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +1162 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +4493 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +3030 -0
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +501 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +47 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +261 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-threading.cpp +12 -0
- data/ext/sources/ggml/src/ggml-threading.h +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +10700 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +751 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/sources/ggml/src/ggml.c +6550 -0
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{whisper.h → sources/include/whisper.h} +91 -24
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.h +158 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +226 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.h +154 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +222 -0
- data/ext/sources/src/coreml/whisper-encoder.h +26 -0
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/sources/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{whisper.cpp → sources/src/whisper.cpp} +2535 -835
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +34 -0
- data/lib/whisper/model/uri.rb +178 -0
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +35 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +202 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +109 -0
- data/tests/test_package.rb +46 -0
- data/tests/test_params.rb +297 -0
- data/tests/test_segment.rb +74 -0
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +212 -124
- data/whispercpp.gemspec +37 -0
- metadata +794 -13
- data/ext/dr_wav.h +0 -6434
- data/ext/ggml.c +0 -21755
- data/ext/ruby_whisper.cpp +0 -426
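The listing above shows the Ruby extension being split from the old single `ruby_whisper.cpp` into per-concern C sources (`ruby_whisper_context.c`, `ruby_whisper_params.c`, `ruby_whisper_segment.c`, `ruby_whisper_vad_params.c`, ...), with the upstream whisper.cpp/ggml tree now vendored under `data/ext/sources/`. For orientation only, a minimal transcription with the gem might look like the sketch below; the class and method names (`Whisper::Context`, `Whisper::Params`, `#transcribe`, `#each_segment`) follow the gem's documented API but are assumptions, not verified against this exact release.

```ruby
require "whisper"  # gem "whispercpp"; entry point assumed from data/lib/whisper/

# Model names such as "base.en" are resolved by data/lib/whisper/model/uri.rb.
whisper = Whisper::Context.new("base.en")

params = Whisper::Params.new
params.language = "en"  # setters are defined in ext/ruby_whisper_params.c (names assumed)

# Transcribe a 16 kHz WAV file and print the full text.
whisper.transcribe("jfk.wav", params) do |whole_text|
  puts whole_text
end

# Per-segment access is provided by ext/ruby_whisper_segment.c (method names assumed).
whisper.each_segment do |segment|
  printf("[%s] %s\n", segment.start_time, segment.text)
end
```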
data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -0,0 +1,1813 @@
+#include "ggml-rpc.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml-cpp.h"
+
+#include <cinttypes>
+#include <string>
+#include <vector>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <unordered_set>
+#ifdef _WIN32
+#  define WIN32_LEAN_AND_MEAN
+#  ifndef NOMINMAX
+#    define NOMINMAX
+#  endif
+#  include <windows.h>
+#  include <winsock2.h>
+#else
+#  include <arpa/inet.h>
+#  include <sys/socket.h>
+#  include <sys/types.h>
+#  include <netinet/in.h>
+#  include <netinet/tcp.h>
+#  include <netdb.h>
+#  include <unistd.h>
+#endif
+#include <cstring>
+#include <fstream>
+#include <filesystem>
+
+namespace fs = std::filesystem;
+
+#ifdef _WIN32
+typedef SOCKET sockfd_t;
+using ssize_t = __int64;
+#else
+typedef int sockfd_t;
+#endif
+
+// cross-platform socket
+struct socket_t {
+    sockfd_t fd;
+    socket_t(sockfd_t fd) : fd(fd) {}
+    ~socket_t() {
+        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
+#ifdef _WIN32
+        closesocket(this->fd);
+#else
+        close(this->fd);
+#endif
+    }
+};
+
+// all RPC structures must be packed
+#pragma pack(push, 1)
+// ggml_tensor is serialized into rpc_tensor
+struct rpc_tensor {
+    uint64_t id;
+    uint32_t type;
+    uint64_t buffer;
+    uint32_t ne[GGML_MAX_DIMS];
+    uint32_t nb[GGML_MAX_DIMS];
+    uint32_t op;
+    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+    int32_t flags;
+    uint64_t src[GGML_MAX_SRC];
+    uint64_t view_src;
+    uint64_t view_offs;
+    uint64_t data;
+    char name[GGML_MAX_NAME];
+
+    char padding[4];
+};
+
+static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
+
+// RPC commands
+enum rpc_cmd {
+    RPC_CMD_ALLOC_BUFFER = 0,
+    RPC_CMD_GET_ALIGNMENT,
+    RPC_CMD_GET_MAX_SIZE,
+    RPC_CMD_BUFFER_GET_BASE,
+    RPC_CMD_FREE_BUFFER,
+    RPC_CMD_BUFFER_CLEAR,
+    RPC_CMD_SET_TENSOR,
+    RPC_CMD_SET_TENSOR_HASH,
+    RPC_CMD_GET_TENSOR,
+    RPC_CMD_COPY_TENSOR,
+    RPC_CMD_GRAPH_COMPUTE,
+    RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_INIT_TENSOR,
+    RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_HELLO,
+    RPC_CMD_COUNT,
+};
+
+// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+
+struct rpc_msg_hello_rsp {
+    uint8_t major;
+    uint8_t minor;
+    uint8_t patch;
+};
+
+struct rpc_msg_get_alloc_size_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_get_alloc_size_rsp {
+    uint64_t alloc_size;
+};
+
+struct rpc_msg_init_tensor_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_alloc_buffer_req {
+    uint64_t size;
+};
+
+struct rpc_msg_alloc_buffer_rsp {
+    uint64_t remote_ptr;
+    uint64_t remote_size;
+};
+
+struct rpc_msg_get_alignment_rsp {
+    uint64_t alignment;
+};
+
+struct rpc_msg_get_max_size_rsp {
+    uint64_t max_size;
+};
+
+struct rpc_msg_buffer_get_base_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_get_base_rsp {
+    uint64_t base_ptr;
+};
+
+struct rpc_msg_free_buffer_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_clear_req {
+    uint64_t remote_ptr;
+    uint8_t value;
+};
+
+struct rpc_msg_set_tensor_hash_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t hash;
+};
+
+struct rpc_msg_set_tensor_hash_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_get_tensor_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t size;
+};
+
+struct rpc_msg_copy_tensor_req {
+    rpc_tensor src;
+    rpc_tensor dst;
+};
+
+struct rpc_msg_copy_tensor_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_graph_compute_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_get_device_memory_rsp {
+    uint64_t free_mem;
+    uint64_t total_mem;
+};
+#pragma pack(pop)
+
+// RPC data structures
+
+static ggml_guid_t ggml_backend_rpc_guid() {
+    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
+    return &guid;
+}
+
+struct ggml_backend_rpc_buffer_type_context {
+    std::string endpoint;
+    std::string name;
+    size_t alignment;
+    size_t max_size;
+};
+
+struct ggml_backend_rpc_context {
+    std::string endpoint;
+    std::string name;
+};
+
+struct ggml_backend_rpc_buffer_context {
+    std::shared_ptr<socket_t> sock;
+    void * base_ptr;
+    uint64_t remote_ptr;
+};
+
+// RPC helper functions
+
+// Computes FNV-1a hash of the data
+static uint64_t fnv_hash(const uint8_t * data, size_t len) {
+    const uint64_t fnv_prime = 0x100000001b3ULL;
+    uint64_t hash = 0xcbf29ce484222325ULL;
+
+    for (size_t i = 0; i < len; ++i) {
+        hash ^= data[i];
+        hash *= fnv_prime;
+    }
+    return hash;
|
226
|
+
}
|
227
|
+
|
228
|
+
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
|
229
|
+
#ifdef _WIN32
|
230
|
+
if (fd == INVALID_SOCKET) {
|
231
|
+
return nullptr;
|
232
|
+
}
|
233
|
+
#else
|
234
|
+
if (fd < 0) {
|
235
|
+
return nullptr;
|
236
|
+
}
|
237
|
+
#endif
|
238
|
+
return std::make_shared<socket_t>(fd);
|
239
|
+
}
|
240
|
+
|
241
|
+
static bool set_no_delay(sockfd_t sockfd) {
|
242
|
+
int flag = 1;
|
243
|
+
// set TCP_NODELAY to disable Nagle's algorithm
|
244
|
+
int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
|
245
|
+
return ret == 0;
|
246
|
+
}
|
247
|
+
|
248
|
+
static bool set_reuse_addr(sockfd_t sockfd) {
|
249
|
+
int flag = 1;
|
250
|
+
int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
|
251
|
+
return ret == 0;
|
252
|
+
}
|
253
|
+
|
254
|
+
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
255
|
+
struct sockaddr_in addr;
|
256
|
+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
257
|
+
auto sock_ptr = make_socket(sockfd);
|
258
|
+
if (sock_ptr == nullptr) {
|
259
|
+
return nullptr;
|
260
|
+
}
|
261
|
+
if (!set_no_delay(sockfd)) {
|
262
|
+
fprintf(stderr, "Failed to set TCP_NODELAY\n");
|
263
|
+
return nullptr;
|
264
|
+
}
|
265
|
+
addr.sin_family = AF_INET;
|
266
|
+
addr.sin_port = htons(port);
|
267
|
+
struct hostent * server = gethostbyname(host);
|
268
|
+
if (server == NULL) {
|
269
|
+
fprintf(stderr, "Cannot resolve host '%s'\n", host);
|
270
|
+
return nullptr;
|
271
|
+
}
|
272
|
+
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
|
273
|
+
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
|
274
|
+
return nullptr;
|
275
|
+
}
|
276
|
+
return sock_ptr;
|
277
|
+
}
|
278
|
+
|
279
|
+
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
|
280
|
+
auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
|
281
|
+
auto client_socket = make_socket(client_socket_fd);
|
282
|
+
if (client_socket == nullptr) {
|
283
|
+
return nullptr;
|
284
|
+
}
|
285
|
+
if (!set_no_delay(client_socket_fd)) {
|
286
|
+
fprintf(stderr, "Failed to set TCP_NODELAY\n");
|
287
|
+
return nullptr;
|
288
|
+
}
|
289
|
+
return client_socket;
|
290
|
+
}
|
291
|
+
|
292
|
+
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
|
293
|
+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
294
|
+
auto sock = make_socket(sockfd);
|
295
|
+
if (sock == nullptr) {
|
296
|
+
return nullptr;
|
297
|
+
}
|
298
|
+
if (!set_reuse_addr(sockfd)) {
|
299
|
+
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
|
300
|
+
return nullptr;
|
301
|
+
}
|
302
|
+
if (inet_addr(host) == INADDR_NONE) {
|
303
|
+
fprintf(stderr, "Invalid host address: %s\n", host);
|
304
|
+
return nullptr;
|
305
|
+
}
|
306
|
+
struct sockaddr_in serv_addr;
|
307
|
+
serv_addr.sin_family = AF_INET;
|
308
|
+
serv_addr.sin_addr.s_addr = inet_addr(host);
|
309
|
+
serv_addr.sin_port = htons(port);
|
310
|
+
|
311
|
+
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
|
312
|
+
return nullptr;
|
313
|
+
}
|
314
|
+
if (listen(sockfd, 1) < 0) {
|
315
|
+
return nullptr;
|
316
|
+
}
|
317
|
+
return sock;
|
318
|
+
}
|
319
|
+
|
320
|
+
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
|
321
|
+
size_t bytes_sent = 0;
|
322
|
+
while (bytes_sent < size) {
|
323
|
+
ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
|
324
|
+
if (n < 0) {
|
325
|
+
return false;
|
326
|
+
}
|
327
|
+
bytes_sent += n;
|
328
|
+
}
|
329
|
+
return true;
|
330
|
+
}
|
331
|
+
|
332
|
+
static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
333
|
+
size_t bytes_recv = 0;
|
334
|
+
while (bytes_recv < size) {
|
335
|
+
ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
|
336
|
+
if (n <= 0) {
|
337
|
+
return false;
|
338
|
+
}
|
339
|
+
bytes_recv += n;
|
340
|
+
}
|
341
|
+
return true;
|
342
|
+
}
|
343
|
+
|
344
|
+
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
|
345
|
+
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
|
346
|
+
return false;
|
347
|
+
}
|
348
|
+
return send_data(sockfd, msg, msg_size);
|
349
|
+
}
|
350
|
+
|
351
|
+
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
|
352
|
+
uint64_t size;
|
353
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
354
|
+
return false;
|
355
|
+
}
|
356
|
+
if (size != msg_size) {
|
357
|
+
return false;
|
358
|
+
}
|
359
|
+
return recv_data(sockfd, msg, msg_size);
|
360
|
+
}
|
361
|
+
|
362
|
+
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
363
|
+
uint64_t size;
|
364
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
365
|
+
return false;
|
366
|
+
}
|
367
|
+
try {
|
368
|
+
input.resize(size);
|
369
|
+
} catch (const std::bad_alloc & e) {
|
370
|
+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
|
371
|
+
return false;
|
372
|
+
}
|
373
|
+
return recv_data(sockfd, input.data(), size);
|
374
|
+
}
|
375
|
+
|
376
|
+
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
377
|
+
size_t pos = endpoint.find(':');
|
378
|
+
if (pos == std::string::npos) {
|
379
|
+
return false;
|
380
|
+
}
|
381
|
+
host = endpoint.substr(0, pos);
|
382
|
+
port = std::stoi(endpoint.substr(pos + 1));
|
383
|
+
return true;
|
384
|
+
}
|
385
|
+
|
386
|
+
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
387
|
+
// No response
|
388
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
|
389
|
+
uint8_t cmd_byte = cmd;
|
390
|
+
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
391
|
+
return false;
|
392
|
+
}
|
393
|
+
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
|
394
|
+
return false;
|
395
|
+
}
|
396
|
+
if (!send_data(sock->fd, input, input_size)) {
|
397
|
+
return false;
|
398
|
+
}
|
399
|
+
return true;
|
400
|
+
}
|
401
|
+
|
402
|
+
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
403
|
+
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
404
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
405
|
+
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
|
406
|
+
return false;
|
407
|
+
}
|
408
|
+
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
409
|
+
// even if we do, we can skip sending output_size from the server for commands with known output size
|
410
|
+
uint64_t out_size;
|
411
|
+
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
|
412
|
+
return false;
|
413
|
+
}
|
414
|
+
if (out_size != output_size) {
|
415
|
+
return false;
|
416
|
+
}
|
417
|
+
if (!recv_data(sock->fd, output, output_size)) {
|
418
|
+
return false;
|
419
|
+
}
|
420
|
+
return true;
|
421
|
+
}
|
422
|
+
|
423
|
+
// RPC client-side implementation
|
424
|
+
|
425
|
+
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
426
|
+
rpc_msg_hello_rsp response;
|
427
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
428
|
+
GGML_ASSERT(status);
|
429
|
+
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
430
|
+
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
431
|
+
return false;
|
432
|
+
}
|
433
|
+
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
434
|
+
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
435
|
+
}
|
436
|
+
return true;
|
437
|
+
}
|
438
|
+
|
439
|
+
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
440
|
+
static std::mutex mutex;
|
441
|
+
std::lock_guard<std::mutex> lock(mutex);
|
442
|
+
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
|
443
|
+
static bool initialized = false;
|
444
|
+
|
445
|
+
auto it = sockets.find(endpoint);
|
446
|
+
if (it != sockets.end()) {
|
447
|
+
if (auto sock = it->second.lock()) {
|
448
|
+
return sock;
|
449
|
+
}
|
450
|
+
}
|
451
|
+
std::string host;
|
452
|
+
int port;
|
453
|
+
if (!parse_endpoint(endpoint, host, port)) {
|
454
|
+
return nullptr;
|
455
|
+
}
|
456
|
+
#ifdef _WIN32
|
457
|
+
if (!initialized) {
|
458
|
+
WSADATA wsaData;
|
459
|
+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
460
|
+
if (res != 0) {
|
461
|
+
return nullptr;
|
462
|
+
}
|
463
|
+
initialized = true;
|
464
|
+
}
|
465
|
+
#else
|
466
|
+
GGML_UNUSED(initialized);
|
467
|
+
#endif
|
468
|
+
auto sock = socket_connect(host.c_str(), port);
|
469
|
+
if (sock == nullptr) {
|
470
|
+
return nullptr;
|
471
|
+
}
|
472
|
+
if (!check_server_version(sock)) {
|
473
|
+
return nullptr;
|
474
|
+
}
|
475
|
+
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
476
|
+
sockets[endpoint] = sock;
|
477
|
+
return sock;
|
478
|
+
}
|
479
|
+
|
480
|
+
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
481
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
482
|
+
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
|
483
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
|
484
|
+
GGML_ASSERT(status);
|
485
|
+
delete ctx;
|
486
|
+
}
|
487
|
+
|
488
|
+
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
489
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
490
|
+
if (ctx->base_ptr != nullptr) {
|
491
|
+
return ctx->base_ptr;
|
492
|
+
}
|
493
|
+
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
494
|
+
rpc_msg_buffer_get_base_rsp response;
|
495
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
496
|
+
GGML_ASSERT(status);
|
497
|
+
ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
498
|
+
return ctx->base_ptr;
|
499
|
+
}
|
500
|
+
|
501
|
+
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
502
|
+
rpc_tensor result;
|
503
|
+
result.id = reinterpret_cast<uint64_t>(tensor);
|
504
|
+
result.type = tensor->type;
|
505
|
+
if (tensor->buffer) {
|
506
|
+
ggml_backend_buffer_t buffer = tensor->buffer;
|
507
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
508
|
+
result.buffer = ctx->remote_ptr;
|
509
|
+
} else {
|
510
|
+
result.buffer = 0;
|
511
|
+
}
|
512
|
+
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
513
|
+
result.ne[i] = tensor->ne[i];
|
514
|
+
result.nb[i] = tensor->nb[i];
|
515
|
+
}
|
516
|
+
result.op = tensor->op;
|
517
|
+
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
|
518
|
+
result.op_params[i] = tensor->op_params[i];
|
519
|
+
}
|
520
|
+
result.flags = tensor->flags;
|
521
|
+
for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
|
522
|
+
result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
|
523
|
+
}
|
524
|
+
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
|
525
|
+
result.view_offs = tensor->view_offs;
|
526
|
+
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
527
|
+
|
528
|
+
// Avoid sending uninitialized data over the wire
|
529
|
+
memset(result.name, 0, sizeof(result.name));
|
530
|
+
memset(result.padding, 0, sizeof(result.padding));
|
531
|
+
|
532
|
+
snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
|
533
|
+
return result;
|
534
|
+
}
|
535
|
+
|
536
|
+
static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
537
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
538
|
+
|
539
|
+
// CUDA backend on the server pads everything to 512 due to CUDA limitations.
|
540
|
+
// Due to bandwidth constraints, we only call the server init tensor functions if necessary.
|
541
|
+
// In particular, only quantized tensors need padding
|
542
|
+
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
|
543
|
+
rpc_msg_init_tensor_req request;
|
544
|
+
|
545
|
+
request.tensor = serialize_tensor(tensor);
|
546
|
+
|
547
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
|
548
|
+
GGML_ASSERT(status);
|
549
|
+
}
|
550
|
+
return GGML_STATUS_SUCCESS;
|
551
|
+
}
|
552
|
+
|
553
|
+
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
554
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
555
|
+
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
556
|
+
if (size > HASH_THRESHOLD) {
|
557
|
+
rpc_msg_set_tensor_hash_req request;
|
558
|
+
request.tensor = rpc_tensor;
|
559
|
+
request.offset = offset;
|
560
|
+
request.hash = fnv_hash((const uint8_t*)data, size);
|
561
|
+
rpc_msg_set_tensor_hash_rsp response;
|
562
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
|
563
|
+
GGML_ASSERT(status);
|
564
|
+
if (response.result) {
|
565
|
+
// the server has the same data, no need to send it
|
566
|
+
return;
|
567
|
+
}
|
568
|
+
}
|
569
|
+
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
|
570
|
+
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
|
571
|
+
std::vector<uint8_t> input(input_size, 0);
|
572
|
+
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
573
|
+
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
574
|
+
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
575
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
|
576
|
+
GGML_ASSERT(status);
|
577
|
+
}
|
578
|
+
|
579
|
+
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
580
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
581
|
+
rpc_msg_get_tensor_req request;
|
582
|
+
request.tensor = serialize_tensor(tensor);
|
583
|
+
request.offset = offset;
|
584
|
+
request.size = size;
|
585
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
|
586
|
+
GGML_ASSERT(status);
|
587
|
+
}
|
588
|
+
|
589
|
+
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
590
|
+
// check if src and dst are on the same server
|
591
|
+
ggml_backend_buffer_t src_buffer = src->buffer;
|
592
|
+
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
593
|
+
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
594
|
+
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
595
|
+
if (src_ctx->sock != dst_ctx->sock) {
|
596
|
+
return false;
|
597
|
+
}
|
598
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
599
|
+
rpc_msg_copy_tensor_req request;
|
600
|
+
request.src = serialize_tensor(src);
|
601
|
+
request.dst = serialize_tensor(dst);
|
602
|
+
rpc_msg_copy_tensor_rsp response;
|
603
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
604
|
+
GGML_ASSERT(status);
|
605
|
+
return response.result;
|
606
|
+
}
|
607
|
+
|
608
|
+
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
609
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
610
|
+
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
|
611
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
|
612
|
+
GGML_ASSERT(status);
|
613
|
+
}
|
614
|
+
|
615
|
+
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
|
616
|
+
/* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
|
617
|
+
/* .get_base = */ ggml_backend_rpc_buffer_get_base,
|
618
|
+
/* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
|
619
|
+
/* .memset_tensor = */ NULL,
|
620
|
+
/* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
|
621
|
+
/* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
|
622
|
+
/* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
|
623
|
+
/* .clear = */ ggml_backend_rpc_buffer_clear,
|
624
|
+
/* .reset = */ NULL,
|
625
|
+
};
|
626
|
+
|
627
|
+
static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
628
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
629
|
+
return buft_ctx->name.c_str();
|
630
|
+
}
|
631
|
+
|
632
|
+
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
633
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
634
|
+
rpc_msg_alloc_buffer_req request = {size};
|
635
|
+
rpc_msg_alloc_buffer_rsp response;
|
636
|
+
auto sock = get_socket(buft_ctx->endpoint);
|
637
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
638
|
+
GGML_ASSERT(status);
|
639
|
+
if (response.remote_ptr != 0) {
|
640
|
+
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
641
|
+
ggml_backend_rpc_buffer_interface,
|
642
|
+
new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
|
643
|
+
response.remote_size);
|
644
|
+
return buffer;
|
645
|
+
} else {
|
646
|
+
return nullptr;
|
647
|
+
}
|
648
|
+
}
|
649
|
+
|
650
|
+
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
651
|
+
rpc_msg_get_alignment_rsp response;
|
652
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
653
|
+
GGML_ASSERT(status);
|
654
|
+
return response.alignment;
|
655
|
+
}
|
656
|
+
|
657
|
+
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
658
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
659
|
+
return buft_ctx->alignment;
|
660
|
+
}
|
661
|
+
|
662
|
+
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
663
|
+
rpc_msg_get_max_size_rsp response;
|
664
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
665
|
+
GGML_ASSERT(status);
|
666
|
+
return response.max_size;
|
667
|
+
}
|
668
|
+
|
669
|
+
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
670
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
671
|
+
return buft_ctx->max_size;
|
672
|
+
}
|
673
|
+
|
674
|
+
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
675
|
+
// See comments in init_tensor.
|
676
|
+
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
|
677
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
678
|
+
auto sock = get_socket(buft_ctx->endpoint);
|
679
|
+
|
680
|
+
rpc_msg_get_alloc_size_req request;
|
681
|
+
|
682
|
+
request.tensor = serialize_tensor(tensor);
|
683
|
+
|
684
|
+
rpc_msg_get_alloc_size_rsp response;
|
685
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
|
686
|
+
GGML_ASSERT(status);
|
687
|
+
|
688
|
+
return response.alloc_size;
|
689
|
+
} else {
|
690
|
+
return ggml_nbytes(tensor);
|
691
|
+
}
|
692
|
+
}
|
693
|
+
|
694
|
+
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
695
|
+
/* .get_name = */ ggml_backend_rpc_buffer_type_name,
|
696
|
+
/* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
|
697
|
+
/* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
|
698
|
+
/* .get_max_size = */ ggml_backend_rpc_get_max_size,
|
699
|
+
/* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
|
700
|
+
/* .is_host = */ NULL,
|
701
|
+
};
|
702
|
+
|
703
|
+
static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
704
|
+
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
705
|
+
|
706
|
+
return rpc_ctx->name.c_str();
|
707
|
+
}
|
708
|
+
|
709
|
+
static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
710
|
+
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
711
|
+
delete rpc_ctx;
|
712
|
+
delete backend;
|
713
|
+
}
|
714
|
+
|
715
|
+
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
716
|
+
GGML_UNUSED(backend);
|
717
|
+
// this is no-op because we don't have any async operations
|
718
|
+
}
|
719
|
+
|
720
|
+
static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
|
721
|
+
if (tensor == nullptr) {
|
722
|
+
return;
|
723
|
+
}
|
724
|
+
if (visited.find(tensor) != visited.end()) {
|
725
|
+
return;
|
726
|
+
}
|
727
|
+
visited.insert(tensor);
|
728
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
729
|
+
add_tensor(tensor->src[i], tensors, visited);
|
730
|
+
}
|
731
|
+
add_tensor(tensor->view_src, tensors, visited);
|
732
|
+
tensors.push_back(serialize_tensor(tensor));
|
733
|
+
}
|
734
|
+
|
735
|
+
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
736
|
+
uint32_t n_nodes = cgraph->n_nodes;
|
737
|
+
std::vector<rpc_tensor> tensors;
|
738
|
+
std::unordered_set<ggml_tensor*> visited;
|
739
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
740
|
+
add_tensor(cgraph->nodes[i], tensors, visited);
|
741
|
+
}
|
742
|
+
// serialization format:
|
743
|
+
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
744
|
+
uint32_t n_tensors = tensors.size();
|
745
|
+
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
746
|
+
output.resize(output_size, 0);
|
747
|
+
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
|
748
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
749
|
+
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
750
|
+
}
|
751
|
+
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
|
752
|
+
*out_ntensors = n_tensors;
|
753
|
+
rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
|
754
|
+
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
755
|
+
}
|
756
|
+
|
757
|
+
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
758
|
+
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
759
|
+
std::vector<uint8_t> input;
|
760
|
+
serialize_graph(cgraph, input);
|
761
|
+
rpc_msg_graph_compute_rsp response;
|
762
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
763
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
764
|
+
GGML_ASSERT(status);
|
765
|
+
return (enum ggml_status)response.result;
|
766
|
+
}
|
767
|
+
|
768
|
+
static ggml_backend_i ggml_backend_rpc_interface = {
|
769
|
+
/* .get_name = */ ggml_backend_rpc_name,
|
770
|
+
/* .free = */ ggml_backend_rpc_free,
|
771
|
+
/* .set_tensor_async = */ NULL,
|
772
|
+
/* .get_tensor_async = */ NULL,
|
773
|
+
/* .cpy_tensor_async = */ NULL,
|
774
|
+
/* .synchronize = */ ggml_backend_rpc_synchronize,
|
775
|
+
/* .graph_plan_create = */ NULL,
|
776
|
+
/* .graph_plan_free = */ NULL,
|
777
|
+
/* .graph_plan_update = */ NULL,
|
778
|
+
/* .graph_plan_compute = */ NULL,
|
779
|
+
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
780
|
+
/* .event_record = */ NULL,
|
781
|
+
/* .event_wait = */ NULL,
|
782
|
+
};
|
783
|
+
|
784
|
+
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
785
|
+
static std::mutex mutex;
|
786
|
+
std::lock_guard<std::mutex> lock(mutex);
|
787
|
+
// NOTE: buffer types are allocated and never freed; this is by design
|
788
|
+
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
789
|
+
auto it = buft_map.find(endpoint);
|
790
|
+
if (it != buft_map.end()) {
|
791
|
+
return it->second;
|
792
|
+
}
|
793
|
+
auto sock = get_socket(endpoint);
|
794
|
+
if (sock == nullptr) {
|
795
|
+
fprintf(stderr, "Failed to connect to %s\n", endpoint);
|
796
|
+
return nullptr;
|
797
|
+
}
|
798
|
+
size_t alignment = get_alignment(sock);
|
799
|
+
size_t max_size = get_max_size(sock);
|
800
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
801
|
+
/* .endpoint = */ endpoint,
|
802
|
+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
803
|
+
/* .alignment = */ alignment,
|
804
|
+
/* .max_size = */ max_size
|
805
|
+
};
|
806
|
+
|
807
|
+
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
808
|
+
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
809
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
810
|
+
/* .context = */ buft_ctx
|
811
|
+
};
|
812
|
+
buft_map[endpoint] = buft;
|
813
|
+
return buft;
|
814
|
+
}
|
815
|
+
|
816
|
+
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
817
|
+
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
818
|
+
/* .endpoint = */ endpoint,
|
819
|
+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
820
|
+
};
|
821
|
+
|
822
|
+
ggml_backend_t backend = new ggml_backend {
|
823
|
+
/* .guid = */ ggml_backend_rpc_guid(),
|
824
|
+
/* .interface = */ ggml_backend_rpc_interface,
|
825
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
826
|
+
/* .context = */ ctx
|
827
|
+
};
|
828
|
+
return backend;
|
829
|
+
}
|
830
|
+
|
831
|
+
bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
832
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
833
|
+
}
|
834
|
+
|
835
|
+
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
836
|
+
rpc_msg_get_device_memory_rsp response;
|
837
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
838
|
+
GGML_ASSERT(status);
|
839
|
+
*free = response.free_mem;
|
840
|
+
*total = response.total_mem;
|
841
|
+
}
|
842
|
+
|
843
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
844
|
+
auto sock = get_socket(endpoint);
|
845
|
+
if (sock == nullptr) {
|
846
|
+
*free = 0;
|
847
|
+
*total = 0;
|
848
|
+
return;
|
849
|
+
}
|
850
|
+
get_device_memory(sock, free, total);
|
851
|
+
}
|
852
|
+
|
853
|
+
// RPC server-side implementation
|
854
|
+
|
855
|
+
class rpc_server {
|
856
|
+
public:
|
857
|
+
rpc_server(ggml_backend_t backend, const char * cache_dir)
|
858
|
+
: backend(backend), cache_dir(cache_dir) {
|
859
|
+
}
|
860
|
+
~rpc_server();
|
861
|
+
|
862
|
+
void hello(rpc_msg_hello_rsp & response);
|
863
|
+
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
864
|
+
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
865
|
+
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
866
|
+
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
867
|
+
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
868
|
+
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
869
|
+
bool set_tensor(const std::vector<uint8_t> & input);
|
870
|
+
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
871
|
+
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
872
|
+
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
873
|
+
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
874
|
+
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
875
|
+
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
876
|
+
|
877
|
+
private:
|
878
|
+
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
879
|
+
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
880
|
+
ggml_tensor * create_node(uint64_t id,
|
881
|
+
struct ggml_context * ctx,
|
882
|
+
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
883
|
+
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
|
884
|
+
|
885
|
+
|
886
|
+
ggml_backend_t backend;
|
887
|
+
const char * cache_dir;
|
888
|
+
std::unordered_set<ggml_backend_buffer_t> buffers;
|
889
|
+
};
|
890
|
+
|
891
|
+
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
892
|
+
response.major = RPC_PROTO_MAJOR_VERSION;
|
893
|
+
response.minor = RPC_PROTO_MINOR_VERSION;
|
894
|
+
response.patch = RPC_PROTO_PATCH_VERSION;
|
895
|
+
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
896
|
+
}
|
897
|
+
|
898
|
+
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
899
|
+
ggml_backend_buffer_type_t buft;
|
900
|
+
struct ggml_init_params params {
|
901
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
902
|
+
/*.mem_buffer =*/ NULL,
|
903
|
+
/*.no_alloc =*/ true,
|
904
|
+
};
|
905
|
+
|
906
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
907
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
908
|
+
ggml_context * ctx = ctx_ptr.get();
|
909
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
910
|
+
|
911
|
+
if (tensor == nullptr) {
|
912
|
+
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
913
|
+
return false;
|
914
|
+
}
|
915
|
+
|
916
|
+
if (tensor->buffer == nullptr) {
|
917
|
+
//No buffer allocated.
|
918
|
+
buft = ggml_backend_get_default_buffer_type(backend);
|
919
|
+
} else {
|
920
|
+
buft = tensor->buffer->buft;
|
921
|
+
}
|
922
|
+
|
923
|
+
response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
|
924
|
+
|
925
|
+
return true;
|
926
|
+
}
|
927
|
+
|
928
|
+
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
929
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
930
|
+
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
931
|
+
response.remote_ptr = 0;
|
932
|
+
response.remote_size = 0;
|
933
|
+
if (buffer != nullptr) {
|
934
|
+
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
935
|
+
response.remote_size = buffer->size;
|
936
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
937
|
+
buffers.insert(buffer);
|
938
|
+
} else {
|
939
|
+
GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
940
|
+
}
|
941
|
+
}
|
942
|
+
|
943
|
+
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
|
944
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
945
|
+
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
946
|
+
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
|
947
|
+
response.alignment = alignment;
|
948
|
+
}
|
949
|
+
|
950
|
+
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
|
951
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
952
|
+
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
953
|
+
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
|
954
|
+
response.max_size = max_size;
|
955
|
+
}
|
956
|
+
|
957
|
+
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
958
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
959
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
960
|
+
if (buffers.find(buffer) == buffers.end()) {
|
961
|
+
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
962
|
+
return false;
|
963
|
+
}
|
964
|
+
void * base = ggml_backend_buffer_get_base(buffer);
|
965
|
+
response.base_ptr = reinterpret_cast<uint64_t>(base);
|
966
|
+
return true;
|
967
|
+
}
|
968
|
+
|
969
|
+
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
970
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
971
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
972
|
+
if (buffers.find(buffer) == buffers.end()) {
|
973
|
+
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
974
|
+
return false;
|
975
|
+
}
|
976
|
+
ggml_backend_buffer_free(buffer);
|
977
|
+
buffers.erase(buffer);
|
978
|
+
return true;
|
979
|
+
}
|
980
|
+
|
981
|
+
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
982
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
983
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
984
|
+
if (buffers.find(buffer) == buffers.end()) {
|
985
|
+
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
986
|
+
return false;
|
987
|
+
}
|
988
|
+
ggml_backend_buffer_clear(buffer, request.value);
|
989
|
+
return true;
|
990
|
+
}
|
991
|
+
|
992
|
+
ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
|
993
|
+
// Validate tensor type before using it
|
994
|
+
if (tensor->type >= GGML_TYPE_COUNT) {
|
995
|
+
GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
|
996
|
+
return nullptr;
|
997
|
+
}
|
998
|
+
|
999
|
+
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
|
1000
|
+
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
1001
|
+
|
1002
|
+
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
|
1003
|
+
if (result == nullptr) {
|
1004
|
+
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
|
1005
|
+
return nullptr;
|
1006
|
+
}
|
1007
|
+
|
1008
|
+
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
1009
|
+
result->nb[i] = tensor->nb[i];
|
1010
|
+
}
|
1011
|
+
result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
|
1012
|
+
if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
|
1013
|
+
result->buffer = nullptr;
|
1014
|
+
}
|
1015
|
+
|
1016
|
+
if (result->buffer) {
|
1017
|
+
// require that the tensor data does not go beyond the buffer end
|
1018
|
+
uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
|
1019
|
+
uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
|
1020
|
+
uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
|
1021
|
+
GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
|
1022
|
+
GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
|
1023
|
+
}
|
1024
|
+
|
1025
|
+
result->op = (ggml_op) tensor->op;
|
1026
|
+
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
|
1027
|
+
result->op_params[i] = tensor->op_params[i];
|
1028
|
+
}
|
1029
|
+
result->flags = tensor->flags;
|
1030
|
+
result->data = reinterpret_cast<void *>(tensor->data);
|
1031
|
+
ggml_set_name(result, tensor->name);
|
1032
|
+
return result;
|
1033
|
+
}
|
1034
|
+
|
1035
|
+
|
1036
|
+
bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
1037
|
+
// serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
|
1038
|
+
if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
|
1039
|
+
return false;
|
1040
|
+
}
|
1041
|
+
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
1042
|
+
uint64_t offset;
|
1043
|
+
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
1044
|
+
const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
|
1045
|
+
|
1046
|
+
struct ggml_init_params params {
|
1047
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
1048
|
+
/*.mem_buffer =*/ NULL,
|
1049
|
+
/*.no_alloc =*/ true,
|
1050
|
+
};
|
1051
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1052
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1053
|
+
ggml_context * ctx = ctx_ptr.get();
|
1054
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
1055
|
+
if (tensor == nullptr) {
|
1056
|
+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
1057
|
+
return false;
|
1058
|
+
}
|
1059
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
1060
|
+
|
1061
|
+
// sanitize tensor->data
|
1062
|
+
{
|
1063
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
1064
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
1065
|
+
|
1066
|
+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
1067
|
+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
|
1068
|
+
__func__, in_tensor->data, offset, size, p0, p1);
|
1069
|
+
return false;
|
1070
|
+
}
|
1071
|
+
}
|
1072
|
+
|
1073
|
+
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
|
1074
|
+
if (cache_dir && size > HASH_THRESHOLD) {
|
1075
|
+
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
1076
|
+
char hash_str[17];
|
1077
|
+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
1078
|
+
// save to cache_dir/hash_str
|
1079
|
+
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
1080
|
+
std::ofstream ofs(cache_file, std::ios::binary);
|
1081
|
+
ofs.write((const char *)data, size);
|
1082
|
+
printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
1083
|
+
}
|
1084
|
+
ggml_backend_tensor_set(tensor, data, offset, size);
|
1085
|
+
return true;
|
1086
|
+
}
|
1087
|
+
|
1088
|
+
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
1089
|
+
if (!cache_dir) {
|
1090
|
+
return false;
|
1091
|
+
}
|
1092
|
+
char hash_str[17];
|
1093
|
+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
1094
|
+
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
1095
|
+
if (!fs::exists(cache_file)) {
|
1096
|
+
return false;
|
1097
|
+
}
|
1098
|
+
std::ifstream ifs(cache_file, std::ios::binary);
|
1099
|
+
ifs.seekg(0, std::ios::end);
|
1100
|
+
size_t size = ifs.tellg();
|
1101
|
+
ifs.seekg(0, std::ios::beg);
|
1102
|
+
data.resize(size);
|
1103
|
+
ifs.read((char *)data.data(), size);
|
1104
|
+
return true;
|
1105
|
+
}
|
1106
|
+
|
1107
|
+
bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
|
1108
|
+
{
|
1109
|
+
std::vector<uint8_t> cached_file;
|
1110
|
+
if (!get_cached_file(request.hash, cached_file)) {
|
1111
|
+
response.result = 0;
|
1112
|
+
return true;
|
1113
|
+
}
|
1114
|
+
size_t size = cached_file.size();
|
1115
|
+
struct ggml_init_params params {
|
1116
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
1117
|
+
/*.mem_buffer =*/ NULL,
|
1118
|
+
/*.no_alloc =*/ true,
|
1119
|
+
};
|
1120
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1121
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1122
|
+
ggml_context * ctx = ctx_ptr.get();
|
1123
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
1124
|
+
if (tensor == nullptr) {
|
1125
|
+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
1126
|
+
return false;
|
1127
|
+
}
|
1128
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
|
1129
|
+
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
|
1130
|
+
|
1131
|
+
// sanitize tensor->data
|
1132
|
+
{
|
1133
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
1134
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
1135
|
+
|
1136
|
+
if (request.tensor.data + request.offset < p0
|
1137
|
+
|| request.tensor.data + request.offset >= p1
|
1138
|
+
|| size > (p1 - request.tensor.data - request.offset)) {
|
1139
|
+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
1140
|
+
__func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
|
1141
|
+
return false;
|
1142
|
+
}
|
1143
|
+
}
|
1144
|
+
ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
|
1145
|
+
response.result = 1;
|
1146
|
+
return true;
|
1147
|
+
}
|
1148
|
+
|
1149
|
+
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
1150
|
+
struct ggml_init_params params {
|
1151
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
1152
|
+
/*.mem_buffer =*/ NULL,
|
1153
|
+
/*.no_alloc =*/ true,
|
1154
|
+
};
|
1155
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1156
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1157
|
+
ggml_context * ctx = ctx_ptr.get();
|
1158
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
1159
|
+
if (tensor == nullptr) {
|
1160
|
+
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
|
1161
|
+
return false;
|
1162
|
+
}
|
1163
|
+
|
1164
|
+
// Call the backend's buffer_init_tensor function
|
1165
|
+
ggml_backend_buffer_t buffer = tensor->buffer;
|
1166
|
+
if (buffer && buffer->iface.init_tensor) {
|
1167
|
+
buffer->iface.init_tensor(buffer, tensor);
|
1168
|
+
} else {
|
1169
|
+
GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
|
1170
|
+
}
|
1171
|
+
|
1172
|
+
if (tensor->extra != nullptr) {
|
1173
|
+
// This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
|
1174
|
+
// Currently unimplemented.
|
1175
|
+
GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
|
1176
|
+
return false;
|
1177
|
+
}
|
1178
|
+
|
1179
|
+
return true;
|
1180
|
+
}
|
1181
|
+
|
1182
|
+
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
|
1183
|
+
struct ggml_init_params params {
|
1184
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
1185
|
+
/*.mem_buffer =*/ NULL,
|
1186
|
+
/*.no_alloc =*/ true,
|
1187
|
+
};
|
1188
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1189
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1190
|
+
ggml_context * ctx = ctx_ptr.get();
|
1191
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
1192
|
+
if (tensor == nullptr) {
|
1193
|
+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
1194
|
+
return false;
|
1195
|
+
}
|
1196
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
1197
|
+
|
1198
|
+
// sanitize tensor->data
|
1199
|
+
{
|
1200
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
1201
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
1202
|
+
|
1203
|
+
if (request.tensor.data + request.offset < p0 ||
|
1204
|
+
request.tensor.data + request.offset >= p1 ||
|
1205
|
+
request.size > (p1 - request.tensor.data - request.offset)) {
|
1206
|
+
GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
1207
|
+
__func__, request.tensor.data, request.offset, request.size, p0, p1);
|
1208
|
+
return false;
|
1209
|
+
}
|
1210
|
+
}
|
1211
|
+
|
1212
|
+
response.resize(request.size, 0);
|
1213
|
+
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
1214
|
+
return true;
|
1215
|
+
}
|
1216
|
+
|
1217
|
+
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
|
1218
|
+
struct ggml_init_params params {
|
1219
|
+
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
1220
|
+
/*.mem_buffer =*/ NULL,
|
1221
|
+
/*.no_alloc =*/ true,
|
1222
|
+
};
|
1223
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1224
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1225
|
+
ggml_context * ctx = ctx_ptr.get();
|
1226
|
+
|
1227
|
+
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
1228
|
+
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
1229
|
+
if (src == nullptr || dst == nullptr) {
|
1230
|
+
GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
|
1231
|
+
return false;
|
1232
|
+
}
|
1233
|
+
|
1234
|
+
uint64_t src_size = (uint64_t) ggml_nbytes(src);
|
1235
|
+
uint64_t dst_data = (uint64_t) dst->data;
|
1236
|
+
uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
|
1237
|
+
uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
|
1238
|
+
|
1239
|
+
if (dst_data + src_size > dst_base + dst_buf_sz) {
|
1240
|
+
GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
|
1241
|
+
" write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
|
1242
|
+
" buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
|
1243
|
+
__func__,
|
1244
|
+
dst_data,
|
1245
|
+
dst_data + src_size,
|
1246
|
+
dst_base,
|
1247
|
+
dst_base + dst_buf_sz);
|
1248
|
+
return false;
|
1249
|
+
}
|
1250
|
+
|
1251
|
+
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
|
1252
|
+
__func__, (void*) src->buffer, (void*) dst->buffer);
|
1253
|
+
|
1254
|
+
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
1255
|
+
return true;
|
1256
|
+
}
|
1257
|
+
|
1258
|
+
ggml_tensor * rpc_server::create_node(uint64_t id,
|
1259
|
+
struct ggml_context * ctx,
|
1260
|
+
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
1261
|
+
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
|
1262
|
+
if (tensor_map.find(id) != tensor_map.end()) {
|
1263
|
+
return tensor_map[id];
|
1264
|
+
}
|
1265
|
+
// Safely find the tensor pointer
|
1266
|
+
auto it_ptr = tensor_ptrs.find(id);
|
1267
|
+
if (it_ptr == tensor_ptrs.end()) {
|
1268
|
+
return nullptr;
|
1269
|
+
}
|
1270
|
+
const rpc_tensor * tensor = it_ptr->second;
|
1271
|
+
|
1272
|
+
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
|
1273
|
+
if (result == nullptr) {
|
1274
|
+
return nullptr;
|
1275
|
+
}
|
1276
|
+
tensor_map[id] = result;
|
1277
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
1278
|
+
// Check if the source ID is 0 before calling create_node recursively
|
1279
|
+
if (tensor->src[i] == 0) {
|
1280
|
+
result->src[i] = nullptr;
|
1281
|
+
} else {
|
1282
|
+
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
|
1283
|
+
// If the recursive call failed for a non-zero ID, propagate the error
|
1284
|
+
if (result->src[i] == nullptr) {
|
1285
|
+
GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
|
1286
|
+
__func__, i, tensor->src[i], id);
|
1287
|
+
// Must return nullptr to signal failure up the call stack
|
1288
|
+
return nullptr;
|
1289
|
+
}
|
1290
|
+
}
|
1291
|
+
}
|
1292
|
+
|
1293
|
+
// Handle view_src similarly
|
1294
|
+
if (tensor->view_src == 0) {
|
1295
|
+
result->view_src = nullptr;
|
1296
|
+
} else {
|
1297
|
+
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
|
1298
|
+
// If the recursive call failed for a non-zero ID, propagate the error
|
1299
|
+
if (result->view_src == nullptr) {
|
1300
|
+
GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
|
1301
|
+
__func__, tensor->view_src, id);
|
1302
|
+
// Must return nullptr to signal failure up the call stack
|
1303
|
+
return nullptr;
|
1304
|
+
}
|
1305
|
+
}
|
1306
|
+
result->view_offs = tensor->view_offs;
|
1307
|
+
return result;
|
1308
|
+
}
|
1309
|
+
|
1310
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
1311
|
+
// serialization format:
|
1312
|
+
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
1313
|
+
if (input.size() < sizeof(uint32_t)) {
|
1314
|
+
return false;
|
1315
|
+
}
|
1316
|
+
uint32_t n_nodes;
|
1317
|
+
memcpy(&n_nodes, input.data(), sizeof(n_nodes));
|
1318
|
+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
1319
|
+
return false;
|
1320
|
+
}
|
1321
|
+
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
|
1322
|
+
uint32_t n_tensors;
|
1323
|
+
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
|
1324
|
+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
1325
|
+
return false;
|
1326
|
+
}
|
1327
|
+
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
|
1328
|
+
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
1329
|
+
|
1330
|
+
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
1331
|
+
|
1332
|
+
struct ggml_init_params params = {
|
1333
|
+
/*.mem_size =*/ buf_size,
|
1334
|
+
/*.mem_buffer =*/ NULL,
|
1335
|
+
/*.no_alloc =*/ true,
|
1336
|
+
};
|
1337
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1338
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
1339
|
+
ggml_context * ctx = ctx_ptr.get();
|
1340
|
+
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
1341
|
+
graph->n_nodes = n_nodes;
|
1342
|
+
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
|
1343
|
+
for (uint32_t i = 0; i < n_tensors; i++) {
|
1344
|
+
tensor_ptrs[tensors[i].id] = &tensors[i];
|
1345
|
+
}
|
1346
|
+
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
|
1347
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
1348
|
+
int64_t id;
|
1349
|
+
memcpy(&id, &nodes[i], sizeof(id));
|
1350
|
+
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
1351
|
+
|
1352
|
+
// Check if create_node failed for a *non-zero* ID.
|
1353
|
+
// If id was 0, create_node returning nullptr is expected.
|
1354
|
+
// If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
|
1355
|
+
if (graph->nodes[i] == nullptr && id != 0) {
|
1356
|
+
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
|
1357
|
+
return false;
|
1358
|
+
}
|
1359
|
+
}
|
1360
|
+
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
1361
|
+
response.result = status;
|
1362
|
+
return true;
|
1363
|
+
}
|
1364
|
+
|
1365
|
+
rpc_server::~rpc_server() {
|
1366
|
+
for (auto buffer : buffers) {
|
1367
|
+
ggml_backend_buffer_free(buffer);
|
1368
|
+
}
|
1369
|
+
}
|
1370
|
+
|
1371
|
+
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
    rpc_server server(backend, cache_dir);
    uint8_t cmd;
    if (!recv_data(sockfd, &cmd, 1)) {
        return;
    }
    // the first command sent by the client must be HELLO
    if (cmd != RPC_CMD_HELLO) {
        fprintf(stderr, "Expected HELLO command, update client\n");
        return;
    }
    if (!recv_msg(sockfd, nullptr, 0)) {
        return;
    }
    rpc_msg_hello_rsp response;
    server.hello(response);
    if (!send_msg(sockfd, &response, sizeof(response))) {
        return;
    }
    while (true) {
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        if (cmd >= RPC_CMD_COUNT) {
            // fail fast if the command is invalid
            fprintf(stderr, "Unknown command: %d\n", cmd);
            break;
        }
        switch (cmd) {
            case RPC_CMD_HELLO: {
                // HELLO command is handled above
                return;
            }
            case RPC_CMD_ALLOC_BUFFER: {
                rpc_msg_alloc_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_alloc_buffer_rsp response;
                server.alloc_buffer(request, response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALLOC_SIZE: {
                rpc_msg_get_alloc_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alloc_size_rsp response;
                if (!server.get_alloc_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALIGNMENT: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_alignment_rsp response;
                server.get_alignment(response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_MAX_SIZE: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_max_size_rsp response;
                server.get_max_size(response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_GET_BASE: {
                rpc_msg_buffer_get_base_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_buffer_get_base_rsp response;
                if (!server.buffer_get_base(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_FREE_BUFFER: {
                rpc_msg_free_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.free_buffer(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_CLEAR: {
                rpc_msg_buffer_clear_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.buffer_clear(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.set_tensor(input)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR_HASH: {
                rpc_msg_set_tensor_hash_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_set_tensor_hash_rsp response;
                if (!server.set_tensor_hash(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_INIT_TENSOR: {
                rpc_msg_init_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.init_tensor(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_TENSOR: {
                rpc_msg_get_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                std::vector<uint8_t> response;
                if (!server.get_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, response.data(), response.size())) {
                    return;
                }
                break;
            }
            case RPC_CMD_COPY_TENSOR: {
                rpc_msg_copy_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_copy_tensor_rsp response;
                if (!server.copy_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_COMPUTE: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                rpc_msg_graph_compute_rsp response;
                if (!server.graph_compute(input, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_DEVICE_MEMORY: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_device_memory_rsp response;
                response.free_mem = free_mem;
                response.total_mem = total_mem;
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", cmd);
                return;
            }
        }
    }
}
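
// Every command in the dispatch loop above follows the same request/response shape:
// read a fixed-size request struct, have rpc_server fill in a response, then send the
// fixed-size response back. A hypothetical helper capturing that pattern could look
// like the sketch below (illustrative only, not part of the upstream sources):
template <typename Req, typename Rsp, typename Handler>
static bool example_handle_fixed_cmd(sockfd_t sockfd, Handler && handler) {
    Req request;
    if (!recv_msg(sockfd, &request, sizeof(request))) {
        return false;
    }
    Rsp response;
    if (!handler(request, response)) {
        return false;
    }
    return send_msg(sockfd, &response, sizeof(response));
}
// e.g. example_handle_fixed_cmd<rpc_msg_copy_tensor_req, rpc_msg_copy_tensor_rsp>(
//          sockfd, [&](const auto & req, auto & rsp) { return server.copy_tensor(req, rsp); });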

void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
                                   const char * cache_dir,
                                   size_t free_mem, size_t total_mem) {
    printf("Starting RPC server v%d.%d.%d\n",
           RPC_PROTO_MAJOR_VERSION,
           RPC_PROTO_MINOR_VERSION,
           RPC_PROTO_PATCH_VERSION);
    printf("  endpoint       : %s\n", endpoint);
    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));

    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return;
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            fprintf(stderr, "WSAStartup failed: %d\n", res);
            return;
        }
    }
#endif
    auto server_socket = create_server_socket(host.c_str(), port);
    if (server_socket == nullptr) {
        fprintf(stderr, "Failed to create server socket\n");
        return;
    }
    while (true) {
        auto client_socket = socket_accept(server_socket->fd);
        if (client_socket == nullptr) {
            fprintf(stderr, "Failed to accept client connection\n");
            return;
        }
        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
        fflush(stdout);
        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
        printf("Client connection closed\n");
        fflush(stdout);
    }
#ifdef _WIN32
    WSACleanup();
#endif
}
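
// Usage sketch (illustrative only, not part of the upstream sources): serving the local
// CPU backend over RPC. It assumes ggml_backend_cpu_init() from the core ggml API is
// visible in this translation unit (e.g. via ggml-cpu.h); the endpoint and the memory
// figures are placeholder values.
static void example_run_rpc_server(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();
    const size_t mem = 8ull * 1024 * 1024 * 1024; // advertise 8 GB as both free and total
    // blocks forever, serving one client connection at a time
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", /*cache_dir =*/ nullptr, mem, mem);
    ggml_backend_free(backend); // not reached while the accept loop keeps running
}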

// device interface

struct ggml_backend_rpc_device_context {
    std::string endpoint;
    std::string name;
};

static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ctx->name.c_str();
}

static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ctx->name.c_str();
}

static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);

    GGML_UNUSED(dev);
}

static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    // TODO: obtain value from the server
    return GGML_BACKEND_DEVICE_TYPE_GPU;

    GGML_UNUSED(dev);
}

static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_rpc_device_get_name(dev);
    props->description = ggml_backend_rpc_device_get_description(dev);
    props->type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ false,
    };
}

static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ggml_backend_rpc_init(ctx->endpoint.c_str());

    GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());

    GGML_UNUSED(dev);
}

static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    GGML_UNUSED(dev);
    GGML_UNUSED(op);
    //TODO: call the remote backend and cache the results
    return true;
}

static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
        return false;
    }
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
    return buft_ctx->endpoint == dev_ctx->endpoint;
}

static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
    /* .get_name             = */ ggml_backend_rpc_device_get_name,
    /* .get_description      = */ ggml_backend_rpc_device_get_description,
    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
    /* .get_type             = */ ggml_backend_rpc_device_get_type,
    /* .get_props            = */ ggml_backend_rpc_device_get_props,
    /* .init_backend         = */ ggml_backend_rpc_device_init,
    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};

// backend reg interface

static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    return "RPC";

    GGML_UNUSED(reg);
}

static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    return 0;

    GGML_UNUSED(reg);
}

static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");

    GGML_UNUSED(reg);
    GGML_UNUSED(index);
}

static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
        return (void *)ggml_backend_rpc_add_device;
    }
    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
        return (void *)ggml_backend_rpc_start_server;
    }
    return NULL;

    GGML_UNUSED(reg);
}

static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};

ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg ggml_backend_rpc_reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ NULL,
    };

    return &ggml_backend_rpc_reg;
}
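
// Illustrative sketch (not part of the upstream sources): a loader that only holds the
// generic registry can still reach the RPC-specific entry points through the
// get_proc_address hook above, which is how dynamically loaded backends expose them.
static ggml_backend_dev_t example_add_device_via_reg(const char * endpoint) {
    ggml_backend_reg_t reg = ggml_backend_rpc_reg();
    typedef ggml_backend_dev_t (*add_device_fn_t)(const char *);
    auto add_device = (add_device_fn_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_add_device");
    return add_device ? add_device(endpoint) : nullptr;
}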

ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;

    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    if (dev_map.find(endpoint) != dev_map.end()) {
        return dev_map[endpoint];
    }

    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
    };

    ggml_backend_dev_t dev = new ggml_backend_device {
        /* .iface   = */ ggml_backend_rpc_device_i,
        /* .reg     = */ ggml_backend_rpc_reg(),
        /* .context = */ ctx,
    };

    dev_map[endpoint] = dev;

    return dev;
}
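
// Usage sketch (illustrative only): attaching a remote RPC device and creating a backend
// for it. The endpoint string is a placeholder; ggml_backend_dev_init() comes from the
// generic backend API and ends up calling ggml_backend_rpc_device_init() above.
static ggml_backend_t example_connect_rpc_backend(void) {
    ggml_backend_dev_t dev = ggml_backend_rpc_add_device("192.168.1.10:50052");
    if (dev == nullptr) {
        return nullptr;
    }
    return ggml_backend_dev_init(dev, /*params =*/ nullptr);
}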

GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)