whispercpp 1.3.3 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -29,9 +29,18 @@
|
|
|
29
29
|
#include <cstring>
|
|
30
30
|
#include <fstream>
|
|
31
31
|
#include <filesystem>
|
|
32
|
+
#include <algorithm>
|
|
33
|
+
|
|
34
|
+
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
|
|
35
|
+
|
|
36
|
+
#define LOG_DBG(...) \
|
|
37
|
+
do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)
|
|
38
|
+
|
|
32
39
|
|
|
33
40
|
namespace fs = std::filesystem;
|
|
34
41
|
|
|
42
|
+
static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
|
|
43
|
+
|
|
35
44
|
#ifdef _WIN32
|
|
36
45
|
typedef SOCKET sockfd_t;
|
|
37
46
|
using ssize_t = __int64;
|
|
@@ -44,7 +53,7 @@ struct socket_t {
|
|
|
44
53
|
sockfd_t fd;
|
|
45
54
|
socket_t(sockfd_t fd) : fd(fd) {}
|
|
46
55
|
~socket_t() {
|
|
47
|
-
|
|
56
|
+
LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
|
|
48
57
|
#ifdef _WIN32
|
|
49
58
|
closesocket(this->fd);
|
|
50
59
|
#else
|
|
@@ -96,9 +105,13 @@ enum rpc_cmd {
|
|
|
96
105
|
RPC_CMD_INIT_TENSOR,
|
|
97
106
|
RPC_CMD_GET_ALLOC_SIZE,
|
|
98
107
|
RPC_CMD_HELLO,
|
|
108
|
+
RPC_CMD_DEVICE_COUNT,
|
|
109
|
+
RPC_CMD_GRAPH_RECOMPUTE,
|
|
99
110
|
RPC_CMD_COUNT,
|
|
100
111
|
};
|
|
101
112
|
|
|
113
|
+
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
|
|
114
|
+
|
|
102
115
|
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
|
103
116
|
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
|
104
117
|
|
|
@@ -108,8 +121,14 @@ struct rpc_msg_hello_rsp {
|
|
|
108
121
|
uint8_t patch;
|
|
109
122
|
};
|
|
110
123
|
|
|
124
|
+
struct rpc_msg_device_count_rsp {
|
|
125
|
+
uint32_t device_count;
|
|
126
|
+
};
|
|
127
|
+
|
|
111
128
|
struct rpc_msg_get_alloc_size_req {
|
|
129
|
+
uint32_t device;
|
|
112
130
|
rpc_tensor tensor;
|
|
131
|
+
rpc_tensor srcs[GGML_MAX_SRC];
|
|
113
132
|
};
|
|
114
133
|
|
|
115
134
|
struct rpc_msg_get_alloc_size_rsp {
|
|
@@ -121,6 +140,7 @@ struct rpc_msg_init_tensor_req {
|
|
|
121
140
|
};
|
|
122
141
|
|
|
123
142
|
struct rpc_msg_alloc_buffer_req {
|
|
143
|
+
uint32_t device;
|
|
124
144
|
uint64_t size;
|
|
125
145
|
};
|
|
126
146
|
|
|
@@ -129,10 +149,18 @@ struct rpc_msg_alloc_buffer_rsp {
|
|
|
129
149
|
uint64_t remote_size;
|
|
130
150
|
};
|
|
131
151
|
|
|
152
|
+
struct rpc_msg_get_alignment_req {
|
|
153
|
+
uint32_t device;
|
|
154
|
+
};
|
|
155
|
+
|
|
132
156
|
struct rpc_msg_get_alignment_rsp {
|
|
133
157
|
uint64_t alignment;
|
|
134
158
|
};
|
|
135
159
|
|
|
160
|
+
struct rpc_msg_get_max_size_req {
|
|
161
|
+
uint32_t device;
|
|
162
|
+
};
|
|
163
|
+
|
|
136
164
|
struct rpc_msg_get_max_size_rsp {
|
|
137
165
|
uint64_t max_size;
|
|
138
166
|
};
|
|
@@ -179,14 +207,19 @@ struct rpc_msg_copy_tensor_rsp {
|
|
|
179
207
|
uint8_t result;
|
|
180
208
|
};
|
|
181
209
|
|
|
182
|
-
struct
|
|
183
|
-
|
|
210
|
+
struct rpc_msg_get_device_memory_req {
|
|
211
|
+
uint32_t device;
|
|
184
212
|
};
|
|
185
213
|
|
|
186
214
|
struct rpc_msg_get_device_memory_rsp {
|
|
187
215
|
uint64_t free_mem;
|
|
188
216
|
uint64_t total_mem;
|
|
189
217
|
};
|
|
218
|
+
|
|
219
|
+
struct rpc_msg_graph_recompute_req {
|
|
220
|
+
uint32_t device;
|
|
221
|
+
};
|
|
222
|
+
|
|
190
223
|
#pragma pack(pop)
|
|
191
224
|
|
|
192
225
|
// RPC data structures
|
|
@@ -198,14 +231,41 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|
|
198
231
|
|
|
199
232
|
struct ggml_backend_rpc_buffer_type_context {
|
|
200
233
|
std::string endpoint;
|
|
234
|
+
uint32_t device;
|
|
201
235
|
std::string name;
|
|
202
|
-
size_t
|
|
203
|
-
size_t
|
|
236
|
+
size_t alignment;
|
|
237
|
+
size_t max_size;
|
|
238
|
+
};
|
|
239
|
+
|
|
240
|
+
struct graph_cache {
|
|
241
|
+
|
|
242
|
+
bool is_cached(const ggml_cgraph * cgraph) {
|
|
243
|
+
if ((int)last_graph.size() != cgraph->n_nodes) {
|
|
244
|
+
return false;
|
|
245
|
+
}
|
|
246
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
247
|
+
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
|
248
|
+
return false;
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
return true;
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
void add(const ggml_cgraph * cgraph) {
|
|
255
|
+
last_graph.resize(cgraph->n_nodes);
|
|
256
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
257
|
+
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
std::vector<ggml_tensor> last_graph;
|
|
204
262
|
};
|
|
205
263
|
|
|
206
264
|
struct ggml_backend_rpc_context {
|
|
207
265
|
std::string endpoint;
|
|
266
|
+
uint32_t device;
|
|
208
267
|
std::string name;
|
|
268
|
+
graph_cache gc;
|
|
209
269
|
};
|
|
210
270
|
|
|
211
271
|
struct ggml_backend_rpc_buffer_context {
|
|
@@ -262,14 +322,14 @@ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
|
|
262
322
|
return nullptr;
|
|
263
323
|
}
|
|
264
324
|
if (!set_no_delay(sockfd)) {
|
|
265
|
-
|
|
325
|
+
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
|
|
266
326
|
return nullptr;
|
|
267
327
|
}
|
|
268
328
|
addr.sin_family = AF_INET;
|
|
269
329
|
addr.sin_port = htons(port);
|
|
270
330
|
struct hostent * server = gethostbyname(host);
|
|
271
331
|
if (server == NULL) {
|
|
272
|
-
|
|
332
|
+
GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
|
|
273
333
|
return nullptr;
|
|
274
334
|
}
|
|
275
335
|
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
|
|
@@ -286,7 +346,7 @@ static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
|
|
|
286
346
|
return nullptr;
|
|
287
347
|
}
|
|
288
348
|
if (!set_no_delay(client_socket_fd)) {
|
|
289
|
-
|
|
349
|
+
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
|
|
290
350
|
return nullptr;
|
|
291
351
|
}
|
|
292
352
|
return client_socket;
|
|
@@ -299,11 +359,11 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
|
|
299
359
|
return nullptr;
|
|
300
360
|
}
|
|
301
361
|
if (!set_reuse_addr(sockfd)) {
|
|
302
|
-
|
|
362
|
+
GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
|
|
303
363
|
return nullptr;
|
|
304
364
|
}
|
|
305
365
|
if (inet_addr(host) == INADDR_NONE) {
|
|
306
|
-
|
|
366
|
+
GGML_LOG_ERROR("Invalid host address: %s\n", host);
|
|
307
367
|
return nullptr;
|
|
308
368
|
}
|
|
309
369
|
struct sockaddr_in serv_addr;
|
|
@@ -323,11 +383,14 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
|
|
323
383
|
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
|
|
324
384
|
size_t bytes_sent = 0;
|
|
325
385
|
while (bytes_sent < size) {
|
|
326
|
-
|
|
386
|
+
size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
|
|
387
|
+
ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
|
|
327
388
|
if (n < 0) {
|
|
389
|
+
GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
|
|
390
|
+
bytes_sent, size_to_send);
|
|
328
391
|
return false;
|
|
329
392
|
}
|
|
330
|
-
bytes_sent += n;
|
|
393
|
+
bytes_sent += (size_t)n;
|
|
331
394
|
}
|
|
332
395
|
return true;
|
|
333
396
|
}
|
|
@@ -335,11 +398,18 @@ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
|
|
|
335
398
|
static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|
336
399
|
size_t bytes_recv = 0;
|
|
337
400
|
while (bytes_recv < size) {
|
|
338
|
-
|
|
339
|
-
|
|
401
|
+
size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
|
|
402
|
+
ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
|
|
403
|
+
if (n < 0) {
|
|
404
|
+
GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
|
|
405
|
+
bytes_recv, size_to_recv);
|
|
406
|
+
return false;
|
|
407
|
+
}
|
|
408
|
+
if (n == 0) {
|
|
409
|
+
LOG_DBG("recv returned 0 (peer closed?)\n");
|
|
340
410
|
return false;
|
|
341
411
|
}
|
|
342
|
-
bytes_recv += n;
|
|
412
|
+
bytes_recv += (size_t)n;
|
|
343
413
|
}
|
|
344
414
|
return true;
|
|
345
415
|
}
|
|
@@ -370,7 +440,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
|
|
370
440
|
try {
|
|
371
441
|
input.resize(size);
|
|
372
442
|
} catch (const std::bad_alloc & e) {
|
|
373
|
-
|
|
443
|
+
GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
|
|
374
444
|
return false;
|
|
375
445
|
}
|
|
376
446
|
return recv_data(sockfd, input.data(), size);
|
|
@@ -430,11 +500,11 @@ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
|
|
430
500
|
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
|
431
501
|
RPC_STATUS_ASSERT(status);
|
|
432
502
|
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
|
433
|
-
|
|
503
|
+
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
434
504
|
return false;
|
|
435
505
|
}
|
|
436
506
|
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
|
437
|
-
|
|
507
|
+
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
438
508
|
}
|
|
439
509
|
return true;
|
|
440
510
|
}
|
|
@@ -454,6 +524,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
454
524
|
std::string host;
|
|
455
525
|
int port;
|
|
456
526
|
if (!parse_endpoint(endpoint, host, port)) {
|
|
527
|
+
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
|
|
457
528
|
return nullptr;
|
|
458
529
|
}
|
|
459
530
|
#ifdef _WIN32
|
|
@@ -475,7 +546,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
475
546
|
if (!check_server_version(sock)) {
|
|
476
547
|
return nullptr;
|
|
477
548
|
}
|
|
478
|
-
|
|
549
|
+
LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
|
479
550
|
sockets[endpoint] = sock;
|
|
480
551
|
return sock;
|
|
481
552
|
}
|
|
@@ -501,14 +572,23 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
|
501
572
|
return ctx->base_ptr;
|
|
502
573
|
}
|
|
503
574
|
|
|
575
|
+
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
|
576
|
+
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
|
577
|
+
}
|
|
578
|
+
|
|
504
579
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
505
580
|
rpc_tensor result;
|
|
581
|
+
if (!tensor) {
|
|
582
|
+
memset(&result, 0, sizeof(result));
|
|
583
|
+
return result;
|
|
584
|
+
}
|
|
585
|
+
|
|
506
586
|
result.id = reinterpret_cast<uint64_t>(tensor);
|
|
507
587
|
result.type = tensor->type;
|
|
508
|
-
if (tensor->buffer) {
|
|
588
|
+
if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
|
|
509
589
|
ggml_backend_buffer_t buffer = tensor->buffer;
|
|
510
590
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
511
|
-
result.buffer = ctx->remote_ptr;
|
|
591
|
+
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
|
|
512
592
|
} else {
|
|
513
593
|
result.buffer = 0;
|
|
514
594
|
}
|
|
@@ -590,22 +670,25 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
|
|
|
590
670
|
}
|
|
591
671
|
|
|
592
672
|
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
673
|
+
if (ggml_backend_buffer_is_rpc(src->buffer)) {
|
|
674
|
+
// check if src and dst are on the same server
|
|
675
|
+
ggml_backend_buffer_t src_buffer = src->buffer;
|
|
676
|
+
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
|
677
|
+
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
|
678
|
+
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
|
679
|
+
if (src_ctx->sock != dst_ctx->sock) {
|
|
680
|
+
return false;
|
|
681
|
+
}
|
|
682
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
683
|
+
rpc_msg_copy_tensor_req request;
|
|
684
|
+
request.src = serialize_tensor(src);
|
|
685
|
+
request.dst = serialize_tensor(dst);
|
|
686
|
+
rpc_msg_copy_tensor_rsp response;
|
|
687
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
688
|
+
RPC_STATUS_ASSERT(status);
|
|
689
|
+
return response.result;
|
|
600
690
|
}
|
|
601
|
-
|
|
602
|
-
rpc_msg_copy_tensor_req request;
|
|
603
|
-
request.src = serialize_tensor(src);
|
|
604
|
-
request.dst = serialize_tensor(dst);
|
|
605
|
-
rpc_msg_copy_tensor_rsp response;
|
|
606
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
607
|
-
RPC_STATUS_ASSERT(status);
|
|
608
|
-
return response.result;
|
|
691
|
+
return false;
|
|
609
692
|
}
|
|
610
693
|
|
|
611
694
|
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
@@ -634,7 +717,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
|
|
|
634
717
|
|
|
635
718
|
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
636
719
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
637
|
-
rpc_msg_alloc_buffer_req request = {size};
|
|
720
|
+
rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
|
|
638
721
|
rpc_msg_alloc_buffer_rsp response;
|
|
639
722
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
640
723
|
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
|
@@ -650,9 +733,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
|
|
650
733
|
}
|
|
651
734
|
}
|
|
652
735
|
|
|
653
|
-
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|
736
|
+
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
|
737
|
+
rpc_msg_get_alignment_req request = {device};
|
|
654
738
|
rpc_msg_get_alignment_rsp response;
|
|
655
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT,
|
|
739
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
|
|
656
740
|
RPC_STATUS_ASSERT(status);
|
|
657
741
|
return response.alignment;
|
|
658
742
|
}
|
|
@@ -662,9 +746,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
|
|
|
662
746
|
return buft_ctx->alignment;
|
|
663
747
|
}
|
|
664
748
|
|
|
665
|
-
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|
749
|
+
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
|
750
|
+
rpc_msg_get_max_size_req request = {device};
|
|
666
751
|
rpc_msg_get_max_size_rsp response;
|
|
667
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE,
|
|
752
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
|
|
668
753
|
RPC_STATUS_ASSERT(status);
|
|
669
754
|
return response.max_size;
|
|
670
755
|
}
|
|
@@ -675,23 +760,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
|
675
760
|
}
|
|
676
761
|
|
|
677
762
|
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
|
763
|
+
// should we query the remote server for the actual size
|
|
764
|
+
bool rpc_get = false;
|
|
765
|
+
|
|
678
766
|
// See comments in init_tensor.
|
|
679
|
-
|
|
767
|
+
rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
|
|
768
|
+
|
|
769
|
+
// ops that require additional memory for fleeting data on certain backends
|
|
770
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/15966
|
|
771
|
+
rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
|
|
772
|
+
rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
|
|
773
|
+
|
|
774
|
+
if (rpc_get) {
|
|
680
775
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
681
776
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
682
777
|
|
|
683
|
-
rpc_msg_get_alloc_size_req request
|
|
778
|
+
rpc_msg_get_alloc_size_req request = {
|
|
779
|
+
/*.device =*/ buft_ctx->device,
|
|
780
|
+
/*.tensor =*/ serialize_tensor(tensor),
|
|
781
|
+
/*.srcs =*/ {},
|
|
782
|
+
};
|
|
684
783
|
|
|
685
|
-
|
|
784
|
+
// .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
|
|
785
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
786
|
+
request.srcs[i] = serialize_tensor(tensor->src[i]);
|
|
787
|
+
}
|
|
686
788
|
|
|
789
|
+
// TODO: cache the alloc responses to avoid extra RPC calls?
|
|
687
790
|
rpc_msg_get_alloc_size_rsp response;
|
|
688
791
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
|
|
689
792
|
RPC_STATUS_ASSERT(status);
|
|
690
793
|
|
|
691
794
|
return response.alloc_size;
|
|
692
|
-
} else {
|
|
693
|
-
return ggml_nbytes(tensor);
|
|
694
795
|
}
|
|
796
|
+
|
|
797
|
+
return ggml_nbytes(tensor);
|
|
695
798
|
}
|
|
696
799
|
|
|
697
800
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
@@ -735,7 +838,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
|
|
|
735
838
|
tensors.push_back(serialize_tensor(tensor));
|
|
736
839
|
}
|
|
737
840
|
|
|
738
|
-
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
|
841
|
+
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
|
739
842
|
uint32_t n_nodes = cgraph->n_nodes;
|
|
740
843
|
std::vector<rpc_tensor> tensors;
|
|
741
844
|
std::unordered_set<ggml_tensor*> visited;
|
|
@@ -743,29 +846,45 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
|
|
|
743
846
|
add_tensor(cgraph->nodes[i], tensors, visited);
|
|
744
847
|
}
|
|
745
848
|
// serialization format:
|
|
746
|
-
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
849
|
+
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
747
850
|
uint32_t n_tensors = tensors.size();
|
|
748
|
-
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
|
851
|
+
int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
|
749
852
|
output.resize(output_size, 0);
|
|
750
|
-
|
|
853
|
+
uint8_t * dest = output.data();
|
|
854
|
+
memcpy(dest, &device, sizeof(device));
|
|
855
|
+
dest += sizeof(device);
|
|
856
|
+
memcpy(dest, &n_nodes, sizeof(n_nodes));
|
|
857
|
+
dest += sizeof(n_nodes);
|
|
751
858
|
for (uint32_t i = 0; i < n_nodes; i++) {
|
|
752
|
-
memcpy(
|
|
859
|
+
memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
|
753
860
|
}
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
861
|
+
dest += n_nodes * sizeof(uint64_t);
|
|
862
|
+
memcpy(dest, &n_tensors, sizeof(n_tensors));
|
|
863
|
+
dest += sizeof(n_tensors);
|
|
864
|
+
rpc_tensor * out_tensors = (rpc_tensor *)dest;
|
|
757
865
|
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
|
758
866
|
}
|
|
759
867
|
|
|
760
868
|
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
761
869
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
870
|
+
|
|
871
|
+
GGML_ASSERT(cgraph->n_nodes > 0);
|
|
872
|
+
bool reuse = rpc_ctx->gc.is_cached(cgraph);
|
|
873
|
+
if (reuse) {
|
|
874
|
+
rpc_msg_graph_recompute_req request;
|
|
875
|
+
request.device = rpc_ctx->device;
|
|
876
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
|
877
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
|
878
|
+
RPC_STATUS_ASSERT(status);
|
|
879
|
+
} else {
|
|
880
|
+
rpc_ctx->gc.add(cgraph);
|
|
881
|
+
std::vector<uint8_t> input;
|
|
882
|
+
serialize_graph(rpc_ctx->device, cgraph, input);
|
|
883
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
|
884
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
|
|
885
|
+
RPC_STATUS_ASSERT(status);
|
|
886
|
+
}
|
|
887
|
+
return GGML_STATUS_SUCCESS;
|
|
769
888
|
}
|
|
770
889
|
|
|
771
890
|
static ggml_backend_i ggml_backend_rpc_interface = {
|
|
@@ -782,51 +901,57 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
|
782
901
|
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
|
783
902
|
/* .event_record = */ NULL,
|
|
784
903
|
/* .event_wait = */ NULL,
|
|
904
|
+
/* .graph_optimize = */ NULL,
|
|
785
905
|
};
|
|
786
906
|
|
|
787
|
-
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
907
|
+
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
|
|
788
908
|
static std::mutex mutex;
|
|
789
909
|
std::lock_guard<std::mutex> lock(mutex);
|
|
910
|
+
std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
|
790
911
|
// NOTE: buffer types are allocated and never freed; this is by design
|
|
791
912
|
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
|
792
|
-
auto it = buft_map.find(
|
|
913
|
+
auto it = buft_map.find(buft_name);
|
|
793
914
|
if (it != buft_map.end()) {
|
|
794
915
|
return it->second;
|
|
795
916
|
}
|
|
796
917
|
auto sock = get_socket(endpoint);
|
|
797
918
|
if (sock == nullptr) {
|
|
798
|
-
|
|
919
|
+
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
|
|
799
920
|
return nullptr;
|
|
800
921
|
}
|
|
801
|
-
size_t alignment = get_alignment(sock);
|
|
802
|
-
size_t max_size = get_max_size(sock);
|
|
922
|
+
size_t alignment = get_alignment(sock, device);
|
|
923
|
+
size_t max_size = get_max_size(sock, device);
|
|
803
924
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
|
804
925
|
/* .endpoint = */ endpoint,
|
|
805
|
-
/* .
|
|
926
|
+
/* .device = */ device,
|
|
927
|
+
/* .name = */ buft_name,
|
|
806
928
|
/* .alignment = */ alignment,
|
|
807
929
|
/* .max_size = */ max_size
|
|
808
930
|
};
|
|
809
|
-
|
|
931
|
+
auto reg = ggml_backend_rpc_add_server(endpoint);
|
|
810
932
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
811
933
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
812
|
-
/* .device = */
|
|
934
|
+
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
|
813
935
|
/* .context = */ buft_ctx
|
|
814
936
|
};
|
|
815
|
-
buft_map[
|
|
937
|
+
buft_map[buft_name] = buft;
|
|
816
938
|
return buft;
|
|
817
939
|
}
|
|
818
940
|
|
|
819
|
-
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
941
|
+
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
|
942
|
+
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
|
820
943
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
|
821
|
-
/* .endpoint
|
|
822
|
-
/* .
|
|
944
|
+
/* .endpoint = */ endpoint,
|
|
945
|
+
/* .device = */ device,
|
|
946
|
+
/* .name = */ dev_name,
|
|
947
|
+
/* .gc = */ {},
|
|
823
948
|
};
|
|
824
|
-
|
|
949
|
+
auto reg = ggml_backend_rpc_add_server(endpoint);
|
|
825
950
|
ggml_backend_t backend = new ggml_backend {
|
|
826
|
-
/* .guid
|
|
827
|
-
/* .
|
|
828
|
-
/* .device
|
|
829
|
-
/* .context
|
|
951
|
+
/* .guid = */ ggml_backend_rpc_guid(),
|
|
952
|
+
/* .iface = */ ggml_backend_rpc_interface,
|
|
953
|
+
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
|
954
|
+
/* .context = */ ctx
|
|
830
955
|
};
|
|
831
956
|
return backend;
|
|
832
957
|
}
|
|
@@ -835,37 +960,40 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
|
835
960
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
|
836
961
|
}
|
|
837
962
|
|
|
838
|
-
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
|
963
|
+
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
|
|
964
|
+
rpc_msg_get_device_memory_req request;
|
|
965
|
+
request.device = device;
|
|
839
966
|
rpc_msg_get_device_memory_rsp response;
|
|
840
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY,
|
|
967
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
|
|
841
968
|
RPC_STATUS_ASSERT(status);
|
|
842
969
|
*free = response.free_mem;
|
|
843
970
|
*total = response.total_mem;
|
|
844
971
|
}
|
|
845
972
|
|
|
846
|
-
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
973
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
|
|
847
974
|
auto sock = get_socket(endpoint);
|
|
848
975
|
if (sock == nullptr) {
|
|
849
976
|
*free = 0;
|
|
850
977
|
*total = 0;
|
|
851
978
|
return;
|
|
852
979
|
}
|
|
853
|
-
get_device_memory(sock, free, total);
|
|
980
|
+
get_device_memory(sock, device, free, total);
|
|
854
981
|
}
|
|
855
982
|
|
|
856
983
|
// RPC server-side implementation
|
|
857
984
|
|
|
858
985
|
class rpc_server {
|
|
859
986
|
public:
|
|
860
|
-
rpc_server(ggml_backend_t
|
|
861
|
-
:
|
|
987
|
+
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
|
|
988
|
+
: backends(std::move(all_backends)), cache_dir(cache_dir) {
|
|
989
|
+
stored_graphs.resize(backends.size());
|
|
862
990
|
}
|
|
863
991
|
~rpc_server();
|
|
864
992
|
|
|
865
993
|
void hello(rpc_msg_hello_rsp & response);
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
994
|
+
bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
995
|
+
bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
|
|
996
|
+
bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
|
|
869
997
|
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
|
870
998
|
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
|
871
999
|
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
|
@@ -873,9 +1001,16 @@ public:
|
|
|
873
1001
|
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
|
874
1002
|
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
|
875
1003
|
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
|
876
|
-
bool graph_compute(const std::vector<uint8_t> & input
|
|
1004
|
+
bool graph_compute(const std::vector<uint8_t> & input);
|
|
1005
|
+
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
|
|
877
1006
|
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
|
878
1007
|
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
|
1008
|
+
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
|
1009
|
+
|
|
1010
|
+
struct stored_graph {
|
|
1011
|
+
ggml_context_ptr ctx_ptr;
|
|
1012
|
+
ggml_cgraph * graph;
|
|
1013
|
+
};
|
|
879
1014
|
|
|
880
1015
|
private:
|
|
881
1016
|
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
|
@@ -886,22 +1021,28 @@ private:
|
|
|
886
1021
|
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
|
|
887
1022
|
|
|
888
1023
|
|
|
889
|
-
ggml_backend_t
|
|
1024
|
+
std::vector<ggml_backend_t> backends;
|
|
890
1025
|
const char * cache_dir;
|
|
891
1026
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
1027
|
+
// store the last computed graph for each backend
|
|
1028
|
+
std::vector<stored_graph> stored_graphs;
|
|
892
1029
|
};
|
|
893
1030
|
|
|
894
1031
|
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
|
895
1032
|
response.major = RPC_PROTO_MAJOR_VERSION;
|
|
896
1033
|
response.minor = RPC_PROTO_MINOR_VERSION;
|
|
897
1034
|
response.patch = RPC_PROTO_PATCH_VERSION;
|
|
898
|
-
|
|
1035
|
+
LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
|
899
1036
|
}
|
|
900
1037
|
|
|
901
1038
|
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
|
1039
|
+
uint32_t dev_id = request.device;
|
|
1040
|
+
if (dev_id >= backends.size()) {
|
|
1041
|
+
return false;
|
|
1042
|
+
}
|
|
902
1043
|
ggml_backend_buffer_type_t buft;
|
|
903
1044
|
struct ggml_init_params params {
|
|
904
|
-
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
1045
|
+
/*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
|
|
905
1046
|
/*.mem_buffer =*/ NULL,
|
|
906
1047
|
/*.no_alloc =*/ true,
|
|
907
1048
|
};
|
|
@@ -909,56 +1050,78 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
|
|
909
1050
|
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
910
1051
|
GGML_ASSERT(ctx_ptr != nullptr);
|
|
911
1052
|
ggml_context * ctx = ctx_ptr.get();
|
|
912
|
-
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
913
1053
|
|
|
1054
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
914
1055
|
if (tensor == nullptr) {
|
|
915
1056
|
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
|
916
1057
|
return false;
|
|
917
1058
|
}
|
|
1059
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
1060
|
+
if (request.srcs[i].id != 0) {
|
|
1061
|
+
tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
918
1064
|
|
|
1065
|
+
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
|
|
919
1066
|
if (tensor->buffer == nullptr) {
|
|
920
1067
|
//No buffer allocated.
|
|
921
|
-
buft = ggml_backend_get_default_buffer_type(
|
|
1068
|
+
buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
922
1069
|
} else {
|
|
923
1070
|
buft = tensor->buffer->buft;
|
|
924
1071
|
}
|
|
925
1072
|
|
|
926
|
-
response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
|
|
1073
|
+
response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
|
|
927
1074
|
|
|
928
1075
|
return true;
|
|
929
1076
|
}
|
|
930
1077
|
|
|
931
|
-
|
|
932
|
-
|
|
1078
|
+
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
|
1079
|
+
uint32_t dev_id = request.device;
|
|
1080
|
+
if (dev_id >= backends.size()) {
|
|
1081
|
+
return false;
|
|
1082
|
+
}
|
|
1083
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
933
1084
|
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
|
934
1085
|
response.remote_ptr = 0;
|
|
935
1086
|
response.remote_size = 0;
|
|
936
1087
|
if (buffer != nullptr) {
|
|
937
1088
|
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
938
1089
|
response.remote_size = buffer->size;
|
|
939
|
-
|
|
1090
|
+
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
|
|
1091
|
+
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
|
|
940
1092
|
buffers.insert(buffer);
|
|
941
1093
|
} else {
|
|
942
|
-
|
|
1094
|
+
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
|
|
943
1095
|
}
|
|
1096
|
+
return true;
|
|
944
1097
|
}
|
|
945
1098
|
|
|
946
|
-
|
|
947
|
-
|
|
1099
|
+
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
|
|
1100
|
+
uint32_t dev_id = request.device;
|
|
1101
|
+
if (dev_id >= backends.size()) {
|
|
1102
|
+
return false;
|
|
1103
|
+
}
|
|
1104
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
948
1105
|
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
|
949
|
-
|
|
1106
|
+
LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
|
|
950
1107
|
response.alignment = alignment;
|
|
1108
|
+
return true;
|
|
951
1109
|
}
|
|
952
1110
|
|
|
953
|
-
|
|
954
|
-
|
|
1111
|
+
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
|
|
1112
|
+
uint32_t dev_id = request.device;
|
|
1113
|
+
if (dev_id >= backends.size()) {
|
|
1114
|
+
return false;
|
|
1115
|
+
}
|
|
1116
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
955
1117
|
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
|
956
|
-
|
|
1118
|
+
LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
|
|
957
1119
|
response.max_size = max_size;
|
|
1120
|
+
return true;
|
|
958
1121
|
}
|
|
959
1122
|
|
|
960
1123
|
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
|
961
|
-
|
|
1124
|
+
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
962
1125
|
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
963
1126
|
if (buffers.find(buffer) == buffers.end()) {
|
|
964
1127
|
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
|
@@ -970,7 +1133,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
|
|
|
970
1133
|
}
|
|
971
1134
|
|
|
972
1135
|
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
|
973
|
-
|
|
1136
|
+
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
974
1137
|
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
975
1138
|
if (buffers.find(buffer) == buffers.end()) {
|
|
976
1139
|
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
|
@@ -982,7 +1145,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
|
|
982
1145
|
}
|
|
983
1146
|
|
|
984
1147
|
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
|
985
|
-
|
|
1148
|
+
LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
|
986
1149
|
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
987
1150
|
if (buffers.find(buffer) == buffers.end()) {
|
|
988
1151
|
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
|
|
@@ -1055,11 +1218,11 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
1055
1218
|
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1056
1219
|
ggml_context * ctx = ctx_ptr.get();
|
|
1057
1220
|
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
|
1058
|
-
if (tensor == nullptr) {
|
|
1221
|
+
if (tensor == nullptr || tensor->buffer == nullptr) {
|
|
1059
1222
|
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
1060
1223
|
return false;
|
|
1061
1224
|
}
|
|
1062
|
-
|
|
1225
|
+
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
|
1063
1226
|
|
|
1064
1227
|
// sanitize tensor->data
|
|
1065
1228
|
{
|
|
@@ -1082,7 +1245,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
1082
1245
|
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
|
1083
1246
|
std::ofstream ofs(cache_file, std::ios::binary);
|
|
1084
1247
|
ofs.write((const char *)data, size);
|
|
1085
|
-
|
|
1248
|
+
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
|
1086
1249
|
}
|
|
1087
1250
|
ggml_backend_tensor_set(tensor, data, offset, size);
|
|
1088
1251
|
return true;
|
|
@@ -1095,7 +1258,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
|
|
1095
1258
|
char hash_str[17];
|
|
1096
1259
|
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
|
1097
1260
|
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
|
1098
|
-
|
|
1261
|
+
std::error_code ec;
|
|
1262
|
+
if (!fs::exists(cache_file, ec)) {
|
|
1099
1263
|
return false;
|
|
1100
1264
|
}
|
|
1101
1265
|
std::ifstream ifs(cache_file, std::ios::binary);
|
|
@@ -1124,12 +1288,12 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
|
|
|
1124
1288
|
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1125
1289
|
ggml_context * ctx = ctx_ptr.get();
|
|
1126
1290
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
1127
|
-
if (tensor == nullptr) {
|
|
1291
|
+
if (tensor == nullptr || tensor->buffer == nullptr) {
|
|
1128
1292
|
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
1129
1293
|
return false;
|
|
1130
1294
|
}
|
|
1131
|
-
|
|
1132
|
-
|
|
1295
|
+
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
|
|
1296
|
+
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
|
|
1133
1297
|
|
|
1134
1298
|
// sanitize tensor->data
|
|
1135
1299
|
{
|
|
@@ -1163,7 +1327,7 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
|
|
1163
1327
|
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
|
|
1164
1328
|
return false;
|
|
1165
1329
|
}
|
|
1166
|
-
|
|
1330
|
+
LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
|
|
1167
1331
|
// Call the backend's buffer_init_tensor function
|
|
1168
1332
|
ggml_backend_buffer_t buffer = tensor->buffer;
|
|
1169
1333
|
if (buffer && buffer->iface.init_tensor) {
|
|
@@ -1192,11 +1356,11 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
|
|
|
1192
1356
|
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1193
1357
|
ggml_context * ctx = ctx_ptr.get();
|
|
1194
1358
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
1195
|
-
if (tensor == nullptr) {
|
|
1359
|
+
if (tensor == nullptr || tensor->buffer == nullptr) {
|
|
1196
1360
|
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
1197
1361
|
return false;
|
|
1198
1362
|
}
|
|
1199
|
-
|
|
1363
|
+
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
|
1200
1364
|
|
|
1201
1365
|
// sanitize tensor->data
|
|
1202
1366
|
{
|
|
@@ -1229,7 +1393,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1229
1393
|
|
|
1230
1394
|
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
|
1231
1395
|
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
|
1232
|
-
if (src == nullptr || dst == nullptr) {
|
|
1396
|
+
if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
|
|
1233
1397
|
GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
|
|
1234
1398
|
return false;
|
|
1235
1399
|
}
|
|
@@ -1240,7 +1404,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1240
1404
|
uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
|
|
1241
1405
|
|
|
1242
1406
|
if (dst_data + src_size > dst_base + dst_buf_sz) {
|
|
1243
|
-
|
|
1407
|
+
GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
|
|
1244
1408
|
" write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
|
|
1245
1409
|
" buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
|
|
1246
1410
|
__func__,
|
|
@@ -1251,8 +1415,8 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1251
1415
|
return false;
|
|
1252
1416
|
}
|
|
1253
1417
|
|
|
1254
|
-
|
|
1255
|
-
|
|
1418
|
+
LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
|
|
1419
|
+
__func__, (void*) src->buffer, (void*) dst->buffer);
|
|
1256
1420
|
|
|
1257
1421
|
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
|
1258
1422
|
return true;
|
|
@@ -1310,25 +1474,35 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|
|
1310
1474
|
return result;
|
|
1311
1475
|
}
|
|
1312
1476
|
|
|
1313
|
-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input
|
|
1477
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
|
1314
1478
|
// serialization format:
|
|
1315
|
-
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1316
|
-
if (input.size() < sizeof(uint32_t)) {
|
|
1479
|
+
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1480
|
+
if (input.size() < 2*sizeof(uint32_t)) {
|
|
1481
|
+
return false;
|
|
1482
|
+
}
|
|
1483
|
+
const uint8_t * src = input.data();
|
|
1484
|
+
uint32_t device;
|
|
1485
|
+
memcpy(&device, src, sizeof(device));
|
|
1486
|
+
src += sizeof(device);
|
|
1487
|
+
if (device >= backends.size()) {
|
|
1317
1488
|
return false;
|
|
1318
1489
|
}
|
|
1319
1490
|
uint32_t n_nodes;
|
|
1320
|
-
memcpy(&n_nodes,
|
|
1321
|
-
|
|
1491
|
+
memcpy(&n_nodes, src, sizeof(n_nodes));
|
|
1492
|
+
src += sizeof(n_nodes);
|
|
1493
|
+
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
|
1322
1494
|
return false;
|
|
1323
1495
|
}
|
|
1324
|
-
const uint64_t * nodes = (const uint64_t *)
|
|
1496
|
+
const uint64_t * nodes = (const uint64_t *)src;
|
|
1497
|
+
src += n_nodes*sizeof(uint64_t);
|
|
1325
1498
|
uint32_t n_tensors;
|
|
1326
|
-
memcpy(&n_tensors,
|
|
1327
|
-
|
|
1499
|
+
memcpy(&n_tensors, src, sizeof(n_tensors));
|
|
1500
|
+
src += sizeof(n_tensors);
|
|
1501
|
+
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
|
1328
1502
|
return false;
|
|
1329
1503
|
}
|
|
1330
|
-
const rpc_tensor * tensors = (const rpc_tensor *)
|
|
1331
|
-
|
|
1504
|
+
const rpc_tensor * tensors = (const rpc_tensor *)src;
|
|
1505
|
+
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
|
1332
1506
|
|
|
1333
1507
|
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
|
1334
1508
|
|
|
@@ -1343,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1343
1517
|
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
|
1344
1518
|
graph->n_nodes = n_nodes;
|
|
1345
1519
|
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
|
|
1520
|
+
tensor_ptrs.reserve(n_tensors);
|
|
1346
1521
|
for (uint32_t i = 0; i < n_tensors; i++) {
|
|
1347
|
-
tensor_ptrs
|
|
1522
|
+
tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
|
|
1348
1523
|
}
|
|
1349
1524
|
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
|
|
1525
|
+
tensor_map.reserve(n_nodes);
|
|
1350
1526
|
for (uint32_t i = 0; i < n_nodes; i++) {
|
|
1351
1527
|
int64_t id;
|
|
1352
1528
|
memcpy(&id, &nodes[i], sizeof(id));
|
|
@@ -1360,8 +1536,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1360
1536
|
return false;
|
|
1361
1537
|
}
|
|
1362
1538
|
}
|
|
1363
|
-
ggml_status status = ggml_backend_graph_compute(
|
|
1364
|
-
|
|
1539
|
+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
|
1540
|
+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
|
1541
|
+
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
|
1542
|
+
stored_graphs[device].graph = graph;
|
|
1543
|
+
return true;
|
|
1544
|
+
}
|
|
1545
|
+
|
|
1546
|
+
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
|
|
1547
|
+
uint32_t device = request.device;
|
|
1548
|
+
if (device >= backends.size()) {
|
|
1549
|
+
return false;
|
|
1550
|
+
}
|
|
1551
|
+
if (stored_graphs[device].graph == nullptr) {
|
|
1552
|
+
return false;
|
|
1553
|
+
}
|
|
1554
|
+
ggml_cgraph * graph = stored_graphs[device].graph;
|
|
1555
|
+
LOG_DBG("[%s] device: %u\n", __func__, device);
|
|
1556
|
+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
|
1557
|
+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
|
1558
|
+
return true;
|
|
1559
|
+
}
|
|
1560
|
+
|
|
1561
|
+
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
|
|
1562
|
+
uint32_t dev_id = request.device;
|
|
1563
|
+
if (dev_id >= backends.size()) {
|
|
1564
|
+
return false;
|
|
1565
|
+
}
|
|
1566
|
+
size_t free, total;
|
|
1567
|
+
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
|
|
1568
|
+
ggml_backend_dev_memory(dev, &free, &total);
|
|
1569
|
+
response.free_mem = free;
|
|
1570
|
+
response.total_mem = total;
|
|
1571
|
+
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
|
|
1365
1572
|
return true;
|
|
1366
1573
|
}
|
|
1367
1574
|
|
|
@@ -1371,16 +1578,16 @@ rpc_server::~rpc_server() {
|
|
|
1371
1578
|
}
|
|
1372
1579
|
}
|
|
1373
1580
|
|
|
1374
|
-
static void rpc_serve_client(ggml_backend_t
|
|
1375
|
-
sockfd_t sockfd
|
|
1376
|
-
rpc_server server(
|
|
1581
|
+
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
|
1582
|
+
sockfd_t sockfd) {
|
|
1583
|
+
rpc_server server(backends, cache_dir);
|
|
1377
1584
|
uint8_t cmd;
|
|
1378
1585
|
if (!recv_data(sockfd, &cmd, 1)) {
|
|
1379
1586
|
return;
|
|
1380
1587
|
}
|
|
1381
1588
|
// the first command sent by the client must be HELLO
|
|
1382
1589
|
if (cmd != RPC_CMD_HELLO) {
|
|
1383
|
-
|
|
1590
|
+
GGML_LOG_ERROR("Expected HELLO command, update client\n");
|
|
1384
1591
|
return;
|
|
1385
1592
|
}
|
|
1386
1593
|
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
@@ -1397,7 +1604,7 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1397
1604
|
}
|
|
1398
1605
|
if (cmd >= RPC_CMD_COUNT) {
|
|
1399
1606
|
// fail fast if the command is invalid
|
|
1400
|
-
|
|
1607
|
+
GGML_LOG_ERROR("Unknown command: %d\n", cmd);
|
|
1401
1608
|
break;
|
|
1402
1609
|
}
|
|
1403
1610
|
switch (cmd) {
|
|
@@ -1405,13 +1612,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1405
1612
|
// HELLO command is handled above
|
|
1406
1613
|
return;
|
|
1407
1614
|
}
|
|
1615
|
+
case RPC_CMD_DEVICE_COUNT: {
|
|
1616
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1617
|
+
return;
|
|
1618
|
+
}
|
|
1619
|
+
rpc_msg_device_count_rsp response;
|
|
1620
|
+
response.device_count = backends.size();
|
|
1621
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1622
|
+
return;
|
|
1623
|
+
}
|
|
1624
|
+
break;
|
|
1625
|
+
}
|
|
1408
1626
|
case RPC_CMD_ALLOC_BUFFER: {
|
|
1409
1627
|
rpc_msg_alloc_buffer_req request;
|
|
1410
1628
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1411
1629
|
return;
|
|
1412
1630
|
}
|
|
1413
1631
|
rpc_msg_alloc_buffer_rsp response;
|
|
1414
|
-
server.alloc_buffer(request, response)
|
|
1632
|
+
if (!server.alloc_buffer(request, response)) {
|
|
1633
|
+
return;
|
|
1634
|
+
}
|
|
1415
1635
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1416
1636
|
return;
|
|
1417
1637
|
}
|
|
@@ -1432,22 +1652,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1432
1652
|
break;
|
|
1433
1653
|
}
|
|
1434
1654
|
case RPC_CMD_GET_ALIGNMENT: {
|
|
1435
|
-
|
|
1655
|
+
rpc_msg_get_alignment_req request;
|
|
1656
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1436
1657
|
return;
|
|
1437
1658
|
}
|
|
1438
1659
|
rpc_msg_get_alignment_rsp response;
|
|
1439
|
-
server.get_alignment(response)
|
|
1660
|
+
if (!server.get_alignment(request, response)) {
|
|
1661
|
+
return;
|
|
1662
|
+
}
|
|
1440
1663
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1441
1664
|
return;
|
|
1442
1665
|
}
|
|
1443
1666
|
break;
|
|
1444
1667
|
}
|
|
1445
1668
|
case RPC_CMD_GET_MAX_SIZE: {
|
|
1446
|
-
|
|
1669
|
+
rpc_msg_get_max_size_req request;
|
|
1670
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1447
1671
|
return;
|
|
1448
1672
|
}
|
|
1449
1673
|
rpc_msg_get_max_size_rsp response;
|
|
1450
|
-
server.get_max_size(response)
|
|
1674
|
+
if (!server.get_max_size(request, response)) {
|
|
1675
|
+
return;
|
|
1676
|
+
}
|
|
1451
1677
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1452
1678
|
return;
|
|
1453
1679
|
}
|
|
@@ -1563,45 +1789,77 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1563
1789
|
if (!recv_msg(sockfd, input)) {
|
|
1564
1790
|
return;
|
|
1565
1791
|
}
|
|
1566
|
-
|
|
1567
|
-
if (!server.graph_compute(input, response)) {
|
|
1792
|
+
if (!server.graph_compute(input)) {
|
|
1568
1793
|
return;
|
|
1569
1794
|
}
|
|
1570
|
-
|
|
1795
|
+
break;
|
|
1796
|
+
}
|
|
1797
|
+
case RPC_CMD_GRAPH_RECOMPUTE: {
|
|
1798
|
+
rpc_msg_graph_recompute_req request;
|
|
1799
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1800
|
+
return;
|
|
1801
|
+
}
|
|
1802
|
+
if (!server.graph_recompute(request)) {
|
|
1571
1803
|
return;
|
|
1572
1804
|
}
|
|
1573
1805
|
break;
|
|
1574
1806
|
}
|
|
1575
1807
|
case RPC_CMD_GET_DEVICE_MEMORY: {
|
|
1576
|
-
|
|
1808
|
+
rpc_msg_get_device_memory_req request;
|
|
1809
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1577
1810
|
return;
|
|
1578
1811
|
}
|
|
1579
1812
|
rpc_msg_get_device_memory_rsp response;
|
|
1580
|
-
|
|
1581
|
-
|
|
1813
|
+
if (!server.get_device_memory(request, response)) {
|
|
1814
|
+
return;
|
|
1815
|
+
}
|
|
1582
1816
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1583
1817
|
return;
|
|
1584
1818
|
}
|
|
1585
1819
|
break;
|
|
1586
1820
|
}
|
|
1587
1821
|
default: {
|
|
1588
|
-
|
|
1822
|
+
GGML_LOG_ERROR("Unknown command: %d\n", cmd);
|
|
1589
1823
|
return;
|
|
1590
1824
|
}
|
|
1591
1825
|
}
|
|
1592
1826
|
}
|
|
1593
1827
|
}
|
|
1594
1828
|
|
|
1595
|
-
void ggml_backend_rpc_start_server(
|
|
1596
|
-
|
|
1597
|
-
|
|
1829
|
+
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
|
1830
|
+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
|
|
1831
|
+
if (n_devices == 0 || devices == nullptr) {
|
|
1832
|
+
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
|
1833
|
+
return;
|
|
1834
|
+
}
|
|
1835
|
+
std::vector<ggml_backend_t> backends;
|
|
1598
1836
|
printf("Starting RPC server v%d.%d.%d\n",
|
|
1599
1837
|
RPC_PROTO_MAJOR_VERSION,
|
|
1600
1838
|
RPC_PROTO_MINOR_VERSION,
|
|
1601
1839
|
RPC_PROTO_PATCH_VERSION);
|
|
1602
1840
|
printf(" endpoint : %s\n", endpoint);
|
|
1603
1841
|
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
|
1604
|
-
printf("
|
|
1842
|
+
printf("Devices:\n");
|
|
1843
|
+
for (size_t i = 0; i < n_devices; i++) {
|
|
1844
|
+
auto dev = devices[i];
|
|
1845
|
+
size_t free, total;
|
|
1846
|
+
ggml_backend_dev_memory(dev, &free, &total);
|
|
1847
|
+
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
|
1848
|
+
total / 1024 / 1024, free / 1024 / 1024);
|
|
1849
|
+
auto backend = ggml_backend_dev_init(dev, nullptr);
|
|
1850
|
+
if (!backend) {
|
|
1851
|
+
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
|
1852
|
+
return;
|
|
1853
|
+
}
|
|
1854
|
+
backends.push_back(backend);
|
|
1855
|
+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
|
1856
|
+
if (reg) {
|
|
1857
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
|
1858
|
+
if (ggml_backend_set_n_threads_fn) {
|
|
1859
|
+
ggml_backend_set_n_threads_fn(backend, n_threads);
|
|
1860
|
+
}
|
|
1861
|
+
}
|
|
1862
|
+
}
|
|
1605
1863
|
|
|
1606
1864
|
std::string host;
|
|
1607
1865
|
int port;
|
|
@@ -1629,22 +1887,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
|
|
|
1629
1887
|
fprintf(stderr, "Failed to accept client connection\n");
|
|
1630
1888
|
return;
|
|
1631
1889
|
}
|
|
1632
|
-
printf("Accepted client connection
|
|
1890
|
+
printf("Accepted client connection\n");
|
|
1633
1891
|
fflush(stdout);
|
|
1634
|
-
rpc_serve_client(
|
|
1892
|
+
rpc_serve_client(backends, cache_dir, client_socket->fd);
|
|
1635
1893
|
printf("Client connection closed\n");
|
|
1636
1894
|
fflush(stdout);
|
|
1637
1895
|
}
|
|
1638
1896
|
#ifdef _WIN32
|
|
1639
1897
|
WSACleanup();
|
|
1640
1898
|
#endif
|
|
1899
|
+
for (auto backend : backends) {
|
|
1900
|
+
ggml_backend_free(backend);
|
|
1901
|
+
}
|
|
1641
1902
|
}
|
|
1642
1903
|
|
|
1643
1904
|
// device interface
|
|
1644
1905
|
|
|
1645
1906
|
struct ggml_backend_rpc_device_context {
|
|
1646
1907
|
std::string endpoint;
|
|
1908
|
+
uint32_t device;
|
|
1647
1909
|
std::string name;
|
|
1910
|
+
std::string description;
|
|
1648
1911
|
};
|
|
1649
1912
|
|
|
1650
1913
|
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
|
@@ -1656,15 +1919,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
|
|
1656
1919
|
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
|
|
1657
1920
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1658
1921
|
|
|
1659
|
-
return ctx->
|
|
1922
|
+
return ctx->description.c_str();
|
|
1660
1923
|
}
|
|
1661
1924
|
|
|
1662
1925
|
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
1663
1926
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1664
1927
|
|
|
1665
|
-
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
|
|
1666
|
-
|
|
1667
|
-
GGML_UNUSED(dev);
|
|
1928
|
+
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
|
|
1668
1929
|
}
|
|
1669
1930
|
|
|
1670
1931
|
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
|
|
@@ -1690,7 +1951,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
|
|
|
1690
1951
|
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
1691
1952
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1692
1953
|
|
|
1693
|
-
return ggml_backend_rpc_init(ctx->endpoint.c_str());
|
|
1954
|
+
return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
|
|
1694
1955
|
|
|
1695
1956
|
GGML_UNUSED(params);
|
|
1696
1957
|
}
|
|
@@ -1698,7 +1959,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
|
|
|
1698
1959
|
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
1699
1960
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1700
1961
|
|
|
1701
|
-
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
1962
|
+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
|
|
1702
1963
|
|
|
1703
1964
|
GGML_UNUSED(dev);
|
|
1704
1965
|
}
|
|
@@ -1716,7 +1977,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
|
|
|
1716
1977
|
}
|
|
1717
1978
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
1718
1979
|
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1719
|
-
return buft_ctx->endpoint == dev_ctx->endpoint;
|
|
1980
|
+
return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
|
|
1720
1981
|
}
|
|
1721
1982
|
|
|
1722
1983
|
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
|
@@ -1739,28 +2000,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
|
|
1739
2000
|
|
|
1740
2001
|
// backend reg interface
|
|
1741
2002
|
|
|
1742
|
-
|
|
1743
|
-
|
|
2003
|
+
struct ggml_backend_rpc_reg_context {
|
|
2004
|
+
std::string name;
|
|
2005
|
+
std::vector<ggml_backend_dev_t> devices;
|
|
2006
|
+
};
|
|
1744
2007
|
|
|
1745
|
-
|
|
2008
|
+
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
|
|
2009
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2010
|
+
return ctx ? ctx->name.c_str() : "RPC";
|
|
1746
2011
|
}
|
|
1747
2012
|
|
|
1748
2013
|
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
|
|
1749
|
-
|
|
1750
|
-
|
|
1751
|
-
GGML_UNUSED(reg);
|
|
2014
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2015
|
+
return ctx ? ctx->devices.size() : 0;
|
|
1752
2016
|
}
|
|
1753
2017
|
|
|
1754
2018
|
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
2019
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2020
|
+
if (ctx == nullptr) {
|
|
2021
|
+
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
|
|
2022
|
+
} else {
|
|
2023
|
+
GGML_ASSERT(index < ctx->devices.size());
|
|
2024
|
+
return ctx->devices[index];
|
|
2025
|
+
}
|
|
1759
2026
|
}
|
|
1760
2027
|
|
|
1761
2028
|
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
|
1762
|
-
if (std::strcmp(name, "
|
|
1763
|
-
return (void *)
|
|
2029
|
+
if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
|
|
2030
|
+
return (void *)ggml_backend_rpc_add_server;
|
|
1764
2031
|
}
|
|
1765
2032
|
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
|
|
1766
2033
|
return (void *)ggml_backend_rpc_start_server;
|
|
@@ -1787,30 +2054,65 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
|
|
|
1787
2054
|
return &ggml_backend_rpc_reg;
|
|
1788
2055
|
}
|
|
1789
2056
|
|
|
1790
|
-
|
|
1791
|
-
|
|
2057
|
+
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
|
|
2058
|
+
auto sock = get_socket(endpoint);
|
|
2059
|
+
if (sock == nullptr) {
|
|
2060
|
+
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
|
|
2061
|
+
return 0;
|
|
2062
|
+
}
|
|
2063
|
+
rpc_msg_device_count_rsp response;
|
|
2064
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
|
|
2065
|
+
RPC_STATUS_ASSERT(status);
|
|
2066
|
+
return response.device_count;
|
|
2067
|
+
}
|
|
2068
|
+
|
|
2069
|
+
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
|
|
2070
|
+
/* .get_name = */ ggml_backend_rpc_reg_get_name,
|
|
2071
|
+
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
|
|
2072
|
+
/* .get_device = */ ggml_backend_rpc_reg_get_device,
|
|
2073
|
+
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
|
|
2074
|
+
};
|
|
1792
2075
|
|
|
2076
|
+
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
|
|
2077
|
+
static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
|
|
1793
2078
|
static std::mutex mutex;
|
|
2079
|
+
static uint32_t dev_id = 0;
|
|
1794
2080
|
std::lock_guard<std::mutex> lock(mutex);
|
|
1795
|
-
|
|
1796
|
-
|
|
1797
|
-
return dev_map[endpoint];
|
|
2081
|
+
if (reg_map.find(endpoint) != reg_map.end()) {
|
|
2082
|
+
return reg_map[endpoint];
|
|
1798
2083
|
}
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
|
|
1802
|
-
|
|
1803
|
-
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
2084
|
+
uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
|
|
2085
|
+
if (dev_count == 0) {
|
|
2086
|
+
return nullptr;
|
|
2087
|
+
}
|
|
2088
|
+
ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
|
|
2089
|
+
ctx->name = "RPC[" + std::string(endpoint) + "]";
|
|
2090
|
+
for (uint32_t ind = 0; ind < dev_count; ind++) {
|
|
2091
|
+
std::string dev_name = "RPC" + std::to_string(dev_id);
|
|
2092
|
+
std::string dev_desc = std::string(endpoint);
|
|
2093
|
+
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
|
|
2094
|
+
/* .endpoint = */ endpoint,
|
|
2095
|
+
/* .device = */ ind,
|
|
2096
|
+
/* .name = */ dev_name,
|
|
2097
|
+
/* .description = */ dev_desc
|
|
2098
|
+
};
|
|
2099
|
+
|
|
2100
|
+
ggml_backend_dev_t dev = new ggml_backend_device {
|
|
2101
|
+
/* .iface = */ ggml_backend_rpc_device_i,
|
|
2102
|
+
/* .reg = */ ggml_backend_rpc_reg(),
|
|
2103
|
+
/* .context = */ dev_ctx,
|
|
2104
|
+
};
|
|
2105
|
+
ctx->devices.push_back(dev);
|
|
2106
|
+
dev_id++;
|
|
2107
|
+
}
|
|
2108
|
+
ggml_backend_reg_t reg = new ggml_backend_reg {
|
|
2109
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
2110
|
+
/* .iface = */ ggml_backend_rpc_reg_interface,
|
|
2111
|
+
/* .context = */ ctx
|
|
1809
2112
|
};
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
return dev;
|
|
2113
|
+
reg_map[endpoint] = reg;
|
|
2114
|
+
return reg;
|
|
1814
2115
|
}
|
|
1815
2116
|
|
|
2117
|
+
|
|
1816
2118
|
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
|