whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -105,9 +105,13 @@ enum rpc_cmd {
|
|
|
105
105
|
RPC_CMD_INIT_TENSOR,
|
|
106
106
|
RPC_CMD_GET_ALLOC_SIZE,
|
|
107
107
|
RPC_CMD_HELLO,
|
|
108
|
+
RPC_CMD_DEVICE_COUNT,
|
|
109
|
+
RPC_CMD_GRAPH_RECOMPUTE,
|
|
108
110
|
RPC_CMD_COUNT,
|
|
109
111
|
};
|
|
110
112
|
|
|
113
|
+
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
|
|
114
|
+
|
|
111
115
|
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
|
112
116
|
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
|
113
117
|
|
|
@@ -117,8 +121,14 @@ struct rpc_msg_hello_rsp {
|
|
|
117
121
|
uint8_t patch;
|
|
118
122
|
};
|
|
119
123
|
|
|
124
|
+
struct rpc_msg_device_count_rsp {
|
|
125
|
+
uint32_t device_count;
|
|
126
|
+
};
|
|
127
|
+
|
|
120
128
|
struct rpc_msg_get_alloc_size_req {
|
|
129
|
+
uint32_t device;
|
|
121
130
|
rpc_tensor tensor;
|
|
131
|
+
rpc_tensor srcs[GGML_MAX_SRC];
|
|
122
132
|
};
|
|
123
133
|
|
|
124
134
|
struct rpc_msg_get_alloc_size_rsp {
|
|
@@ -130,6 +140,7 @@ struct rpc_msg_init_tensor_req {
|
|
|
130
140
|
};
|
|
131
141
|
|
|
132
142
|
struct rpc_msg_alloc_buffer_req {
|
|
143
|
+
uint32_t device;
|
|
133
144
|
uint64_t size;
|
|
134
145
|
};
|
|
135
146
|
|
|
@@ -138,10 +149,18 @@ struct rpc_msg_alloc_buffer_rsp {
|
|
|
138
149
|
uint64_t remote_size;
|
|
139
150
|
};
|
|
140
151
|
|
|
152
|
+
struct rpc_msg_get_alignment_req {
|
|
153
|
+
uint32_t device;
|
|
154
|
+
};
|
|
155
|
+
|
|
141
156
|
struct rpc_msg_get_alignment_rsp {
|
|
142
157
|
uint64_t alignment;
|
|
143
158
|
};
|
|
144
159
|
|
|
160
|
+
struct rpc_msg_get_max_size_req {
|
|
161
|
+
uint32_t device;
|
|
162
|
+
};
|
|
163
|
+
|
|
145
164
|
struct rpc_msg_get_max_size_rsp {
|
|
146
165
|
uint64_t max_size;
|
|
147
166
|
};
|
|
@@ -188,14 +207,19 @@ struct rpc_msg_copy_tensor_rsp {
|
|
|
188
207
|
uint8_t result;
|
|
189
208
|
};
|
|
190
209
|
|
|
191
|
-
struct
|
|
192
|
-
|
|
210
|
+
struct rpc_msg_get_device_memory_req {
|
|
211
|
+
uint32_t device;
|
|
193
212
|
};
|
|
194
213
|
|
|
195
214
|
struct rpc_msg_get_device_memory_rsp {
|
|
196
215
|
uint64_t free_mem;
|
|
197
216
|
uint64_t total_mem;
|
|
198
217
|
};
|
|
218
|
+
|
|
219
|
+
struct rpc_msg_graph_recompute_req {
|
|
220
|
+
uint32_t device;
|
|
221
|
+
};
|
|
222
|
+
|
|
199
223
|
#pragma pack(pop)
|
|
200
224
|
|
|
201
225
|
// RPC data structures
|
|
@@ -207,14 +231,41 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|
|
207
231
|
|
|
208
232
|
struct ggml_backend_rpc_buffer_type_context {
|
|
209
233
|
std::string endpoint;
|
|
234
|
+
uint32_t device;
|
|
210
235
|
std::string name;
|
|
211
|
-
size_t
|
|
212
|
-
size_t
|
|
236
|
+
size_t alignment;
|
|
237
|
+
size_t max_size;
|
|
238
|
+
};
|
|
239
|
+
|
|
240
|
+
struct graph_cache {
|
|
241
|
+
|
|
242
|
+
bool is_cached(const ggml_cgraph * cgraph) {
|
|
243
|
+
if ((int)last_graph.size() != cgraph->n_nodes) {
|
|
244
|
+
return false;
|
|
245
|
+
}
|
|
246
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
247
|
+
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
|
248
|
+
return false;
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
return true;
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
void add(const ggml_cgraph * cgraph) {
|
|
255
|
+
last_graph.resize(cgraph->n_nodes);
|
|
256
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
257
|
+
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
std::vector<ggml_tensor> last_graph;
|
|
213
262
|
};
|
|
214
263
|
|
|
215
264
|
struct ggml_backend_rpc_context {
|
|
216
265
|
std::string endpoint;
|
|
266
|
+
uint32_t device;
|
|
217
267
|
std::string name;
|
|
268
|
+
graph_cache gc;
|
|
218
269
|
};
|
|
219
270
|
|
|
220
271
|
struct ggml_backend_rpc_buffer_context {
|
|
@@ -473,6 +524,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
473
524
|
std::string host;
|
|
474
525
|
int port;
|
|
475
526
|
if (!parse_endpoint(endpoint, host, port)) {
|
|
527
|
+
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
|
|
476
528
|
return nullptr;
|
|
477
529
|
}
|
|
478
530
|
#ifdef _WIN32
|
|
@@ -520,14 +572,23 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
|
520
572
|
return ctx->base_ptr;
|
|
521
573
|
}
|
|
522
574
|
|
|
575
|
+
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
|
576
|
+
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
|
577
|
+
}
|
|
578
|
+
|
|
523
579
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
524
580
|
rpc_tensor result;
|
|
581
|
+
if (!tensor) {
|
|
582
|
+
memset(&result, 0, sizeof(result));
|
|
583
|
+
return result;
|
|
584
|
+
}
|
|
585
|
+
|
|
525
586
|
result.id = reinterpret_cast<uint64_t>(tensor);
|
|
526
587
|
result.type = tensor->type;
|
|
527
|
-
if (tensor->buffer) {
|
|
588
|
+
if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
|
|
528
589
|
ggml_backend_buffer_t buffer = tensor->buffer;
|
|
529
590
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
530
|
-
result.buffer = ctx->remote_ptr;
|
|
591
|
+
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
|
|
531
592
|
} else {
|
|
532
593
|
result.buffer = 0;
|
|
533
594
|
}
|
|
@@ -609,22 +670,25 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
|
|
|
609
670
|
}
|
|
610
671
|
|
|
611
672
|
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
673
|
+
if (ggml_backend_buffer_is_rpc(src->buffer)) {
|
|
674
|
+
// check if src and dst are on the same server
|
|
675
|
+
ggml_backend_buffer_t src_buffer = src->buffer;
|
|
676
|
+
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
|
677
|
+
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
|
678
|
+
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
|
679
|
+
if (src_ctx->sock != dst_ctx->sock) {
|
|
680
|
+
return false;
|
|
681
|
+
}
|
|
682
|
+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
683
|
+
rpc_msg_copy_tensor_req request;
|
|
684
|
+
request.src = serialize_tensor(src);
|
|
685
|
+
request.dst = serialize_tensor(dst);
|
|
686
|
+
rpc_msg_copy_tensor_rsp response;
|
|
687
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
688
|
+
RPC_STATUS_ASSERT(status);
|
|
689
|
+
return response.result;
|
|
619
690
|
}
|
|
620
|
-
|
|
621
|
-
rpc_msg_copy_tensor_req request;
|
|
622
|
-
request.src = serialize_tensor(src);
|
|
623
|
-
request.dst = serialize_tensor(dst);
|
|
624
|
-
rpc_msg_copy_tensor_rsp response;
|
|
625
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
626
|
-
RPC_STATUS_ASSERT(status);
|
|
627
|
-
return response.result;
|
|
691
|
+
return false;
|
|
628
692
|
}
|
|
629
693
|
|
|
630
694
|
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
@@ -653,7 +717,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
|
|
|
653
717
|
|
|
654
718
|
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
655
719
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
656
|
-
rpc_msg_alloc_buffer_req request = {size};
|
|
720
|
+
rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
|
|
657
721
|
rpc_msg_alloc_buffer_rsp response;
|
|
658
722
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
659
723
|
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
|
@@ -669,9 +733,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
|
|
669
733
|
}
|
|
670
734
|
}
|
|
671
735
|
|
|
672
|
-
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|
736
|
+
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
|
737
|
+
rpc_msg_get_alignment_req request = {device};
|
|
673
738
|
rpc_msg_get_alignment_rsp response;
|
|
674
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT,
|
|
739
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
|
|
675
740
|
RPC_STATUS_ASSERT(status);
|
|
676
741
|
return response.alignment;
|
|
677
742
|
}
|
|
@@ -681,9 +746,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
|
|
|
681
746
|
return buft_ctx->alignment;
|
|
682
747
|
}
|
|
683
748
|
|
|
684
|
-
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|
749
|
+
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
|
750
|
+
rpc_msg_get_max_size_req request = {device};
|
|
685
751
|
rpc_msg_get_max_size_rsp response;
|
|
686
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE,
|
|
752
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
|
|
687
753
|
RPC_STATUS_ASSERT(status);
|
|
688
754
|
return response.max_size;
|
|
689
755
|
}
|
|
@@ -694,23 +760,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
|
694
760
|
}
|
|
695
761
|
|
|
696
762
|
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
|
763
|
+
// should we query the remote server for the actual size
|
|
764
|
+
bool rpc_get = false;
|
|
765
|
+
|
|
697
766
|
// See comments in init_tensor.
|
|
698
|
-
|
|
767
|
+
rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
|
|
768
|
+
|
|
769
|
+
// ops that require additional memory for fleeting data on certain backends
|
|
770
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/15966
|
|
771
|
+
rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
|
|
772
|
+
rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
|
|
773
|
+
|
|
774
|
+
if (rpc_get) {
|
|
699
775
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
700
776
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
701
777
|
|
|
702
|
-
rpc_msg_get_alloc_size_req request
|
|
778
|
+
rpc_msg_get_alloc_size_req request = {
|
|
779
|
+
/*.device =*/ buft_ctx->device,
|
|
780
|
+
/*.tensor =*/ serialize_tensor(tensor),
|
|
781
|
+
/*.srcs =*/ {},
|
|
782
|
+
};
|
|
703
783
|
|
|
704
|
-
|
|
784
|
+
// .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
|
|
785
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
786
|
+
request.srcs[i] = serialize_tensor(tensor->src[i]);
|
|
787
|
+
}
|
|
705
788
|
|
|
789
|
+
// TODO: cache the alloc responses to avoid extra RPC calls?
|
|
706
790
|
rpc_msg_get_alloc_size_rsp response;
|
|
707
791
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
|
|
708
792
|
RPC_STATUS_ASSERT(status);
|
|
709
793
|
|
|
710
794
|
return response.alloc_size;
|
|
711
|
-
} else {
|
|
712
|
-
return ggml_nbytes(tensor);
|
|
713
795
|
}
|
|
796
|
+
|
|
797
|
+
return ggml_nbytes(tensor);
|
|
714
798
|
}
|
|
715
799
|
|
|
716
800
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
@@ -754,7 +838,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
|
|
|
754
838
|
tensors.push_back(serialize_tensor(tensor));
|
|
755
839
|
}
|
|
756
840
|
|
|
757
|
-
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
|
841
|
+
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
|
758
842
|
uint32_t n_nodes = cgraph->n_nodes;
|
|
759
843
|
std::vector<rpc_tensor> tensors;
|
|
760
844
|
std::unordered_set<ggml_tensor*> visited;
|
|
@@ -762,29 +846,45 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
|
|
|
762
846
|
add_tensor(cgraph->nodes[i], tensors, visited);
|
|
763
847
|
}
|
|
764
848
|
// serialization format:
|
|
765
|
-
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
849
|
+
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
766
850
|
uint32_t n_tensors = tensors.size();
|
|
767
|
-
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
|
851
|
+
int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
|
768
852
|
output.resize(output_size, 0);
|
|
769
|
-
|
|
853
|
+
uint8_t * dest = output.data();
|
|
854
|
+
memcpy(dest, &device, sizeof(device));
|
|
855
|
+
dest += sizeof(device);
|
|
856
|
+
memcpy(dest, &n_nodes, sizeof(n_nodes));
|
|
857
|
+
dest += sizeof(n_nodes);
|
|
770
858
|
for (uint32_t i = 0; i < n_nodes; i++) {
|
|
771
|
-
memcpy(
|
|
859
|
+
memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
|
772
860
|
}
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
861
|
+
dest += n_nodes * sizeof(uint64_t);
|
|
862
|
+
memcpy(dest, &n_tensors, sizeof(n_tensors));
|
|
863
|
+
dest += sizeof(n_tensors);
|
|
864
|
+
rpc_tensor * out_tensors = (rpc_tensor *)dest;
|
|
776
865
|
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
|
777
866
|
}
|
|
778
867
|
|
|
779
868
|
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
780
869
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
870
|
+
|
|
871
|
+
GGML_ASSERT(cgraph->n_nodes > 0);
|
|
872
|
+
bool reuse = rpc_ctx->gc.is_cached(cgraph);
|
|
873
|
+
if (reuse) {
|
|
874
|
+
rpc_msg_graph_recompute_req request;
|
|
875
|
+
request.device = rpc_ctx->device;
|
|
876
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
|
877
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
|
878
|
+
RPC_STATUS_ASSERT(status);
|
|
879
|
+
} else {
|
|
880
|
+
rpc_ctx->gc.add(cgraph);
|
|
881
|
+
std::vector<uint8_t> input;
|
|
882
|
+
serialize_graph(rpc_ctx->device, cgraph, input);
|
|
883
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
|
884
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
|
|
885
|
+
RPC_STATUS_ASSERT(status);
|
|
886
|
+
}
|
|
887
|
+
return GGML_STATUS_SUCCESS;
|
|
788
888
|
}
|
|
789
889
|
|
|
790
890
|
static ggml_backend_i ggml_backend_rpc_interface = {
|
|
@@ -804,12 +904,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
|
804
904
|
/* .graph_optimize = */ NULL,
|
|
805
905
|
};
|
|
806
906
|
|
|
807
|
-
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
907
|
+
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
|
|
808
908
|
static std::mutex mutex;
|
|
809
909
|
std::lock_guard<std::mutex> lock(mutex);
|
|
910
|
+
std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
|
810
911
|
// NOTE: buffer types are allocated and never freed; this is by design
|
|
811
912
|
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
|
812
|
-
auto it = buft_map.find(
|
|
913
|
+
auto it = buft_map.find(buft_name);
|
|
813
914
|
if (it != buft_map.end()) {
|
|
814
915
|
return it->second;
|
|
815
916
|
}
|
|
@@ -818,34 +919,38 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
|
818
919
|
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
|
|
819
920
|
return nullptr;
|
|
820
921
|
}
|
|
821
|
-
size_t alignment = get_alignment(sock);
|
|
822
|
-
size_t max_size = get_max_size(sock);
|
|
922
|
+
size_t alignment = get_alignment(sock, device);
|
|
923
|
+
size_t max_size = get_max_size(sock, device);
|
|
823
924
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
|
824
925
|
/* .endpoint = */ endpoint,
|
|
825
|
-
/* .
|
|
926
|
+
/* .device = */ device,
|
|
927
|
+
/* .name = */ buft_name,
|
|
826
928
|
/* .alignment = */ alignment,
|
|
827
929
|
/* .max_size = */ max_size
|
|
828
930
|
};
|
|
829
|
-
|
|
931
|
+
auto reg = ggml_backend_rpc_add_server(endpoint);
|
|
830
932
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
831
933
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
832
|
-
/* .device = */
|
|
934
|
+
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
|
833
935
|
/* .context = */ buft_ctx
|
|
834
936
|
};
|
|
835
|
-
buft_map[
|
|
937
|
+
buft_map[buft_name] = buft;
|
|
836
938
|
return buft;
|
|
837
939
|
}
|
|
838
940
|
|
|
839
|
-
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
941
|
+
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
|
942
|
+
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
|
840
943
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
|
841
|
-
/* .endpoint
|
|
842
|
-
/* .
|
|
944
|
+
/* .endpoint = */ endpoint,
|
|
945
|
+
/* .device = */ device,
|
|
946
|
+
/* .name = */ dev_name,
|
|
947
|
+
/* .gc = */ {},
|
|
843
948
|
};
|
|
844
|
-
|
|
949
|
+
auto reg = ggml_backend_rpc_add_server(endpoint);
|
|
845
950
|
ggml_backend_t backend = new ggml_backend {
|
|
846
951
|
/* .guid = */ ggml_backend_rpc_guid(),
|
|
847
952
|
/* .iface = */ ggml_backend_rpc_interface,
|
|
848
|
-
/* .device = */
|
|
953
|
+
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
|
849
954
|
/* .context = */ ctx
|
|
850
955
|
};
|
|
851
956
|
return backend;
|
|
@@ -855,37 +960,40 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
|
855
960
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
|
856
961
|
}
|
|
857
962
|
|
|
858
|
-
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
|
963
|
+
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
|
|
964
|
+
rpc_msg_get_device_memory_req request;
|
|
965
|
+
request.device = device;
|
|
859
966
|
rpc_msg_get_device_memory_rsp response;
|
|
860
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY,
|
|
967
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
|
|
861
968
|
RPC_STATUS_ASSERT(status);
|
|
862
969
|
*free = response.free_mem;
|
|
863
970
|
*total = response.total_mem;
|
|
864
971
|
}
|
|
865
972
|
|
|
866
|
-
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
973
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
|
|
867
974
|
auto sock = get_socket(endpoint);
|
|
868
975
|
if (sock == nullptr) {
|
|
869
976
|
*free = 0;
|
|
870
977
|
*total = 0;
|
|
871
978
|
return;
|
|
872
979
|
}
|
|
873
|
-
get_device_memory(sock, free, total);
|
|
980
|
+
get_device_memory(sock, device, free, total);
|
|
874
981
|
}
|
|
875
982
|
|
|
876
983
|
// RPC server-side implementation
|
|
877
984
|
|
|
878
985
|
class rpc_server {
|
|
879
986
|
public:
|
|
880
|
-
rpc_server(ggml_backend_t
|
|
881
|
-
:
|
|
987
|
+
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
|
|
988
|
+
: backends(std::move(all_backends)), cache_dir(cache_dir) {
|
|
989
|
+
stored_graphs.resize(backends.size());
|
|
882
990
|
}
|
|
883
991
|
~rpc_server();
|
|
884
992
|
|
|
885
993
|
void hello(rpc_msg_hello_rsp & response);
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
994
|
+
bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
995
|
+
bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
|
|
996
|
+
bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
|
|
889
997
|
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
|
890
998
|
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
|
891
999
|
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
|
@@ -893,9 +1001,16 @@ public:
|
|
|
893
1001
|
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
|
894
1002
|
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
|
895
1003
|
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
|
896
|
-
bool graph_compute(const std::vector<uint8_t> & input
|
|
1004
|
+
bool graph_compute(const std::vector<uint8_t> & input);
|
|
1005
|
+
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
|
|
897
1006
|
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
|
898
1007
|
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
|
1008
|
+
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
|
1009
|
+
|
|
1010
|
+
struct stored_graph {
|
|
1011
|
+
ggml_context_ptr ctx_ptr;
|
|
1012
|
+
ggml_cgraph * graph;
|
|
1013
|
+
};
|
|
899
1014
|
|
|
900
1015
|
private:
|
|
901
1016
|
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
|
@@ -906,9 +1021,11 @@ private:
|
|
|
906
1021
|
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
|
|
907
1022
|
|
|
908
1023
|
|
|
909
|
-
ggml_backend_t
|
|
1024
|
+
std::vector<ggml_backend_t> backends;
|
|
910
1025
|
const char * cache_dir;
|
|
911
1026
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
1027
|
+
// store the last computed graph for each backend
|
|
1028
|
+
std::vector<stored_graph> stored_graphs;
|
|
912
1029
|
};
|
|
913
1030
|
|
|
914
1031
|
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
|
@@ -919,9 +1036,13 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
|
|
919
1036
|
}
|
|
920
1037
|
|
|
921
1038
|
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
|
1039
|
+
uint32_t dev_id = request.device;
|
|
1040
|
+
if (dev_id >= backends.size()) {
|
|
1041
|
+
return false;
|
|
1042
|
+
}
|
|
922
1043
|
ggml_backend_buffer_type_t buft;
|
|
923
1044
|
struct ggml_init_params params {
|
|
924
|
-
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
1045
|
+
/*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
|
|
925
1046
|
/*.mem_buffer =*/ NULL,
|
|
926
1047
|
/*.no_alloc =*/ true,
|
|
927
1048
|
};
|
|
@@ -929,16 +1050,22 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
|
|
929
1050
|
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
930
1051
|
GGML_ASSERT(ctx_ptr != nullptr);
|
|
931
1052
|
ggml_context * ctx = ctx_ptr.get();
|
|
932
|
-
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
933
1053
|
|
|
1054
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
934
1055
|
if (tensor == nullptr) {
|
|
935
1056
|
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
|
936
1057
|
return false;
|
|
937
1058
|
}
|
|
938
|
-
|
|
1059
|
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
1060
|
+
if (request.srcs[i].id != 0) {
|
|
1061
|
+
tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
|
|
939
1066
|
if (tensor->buffer == nullptr) {
|
|
940
1067
|
//No buffer allocated.
|
|
941
|
-
buft = ggml_backend_get_default_buffer_type(
|
|
1068
|
+
buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
942
1069
|
} else {
|
|
943
1070
|
buft = tensor->buffer->buft;
|
|
944
1071
|
}
|
|
@@ -948,33 +1075,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
|
|
948
1075
|
return true;
|
|
949
1076
|
}
|
|
950
1077
|
|
|
951
|
-
|
|
952
|
-
|
|
1078
|
+
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
|
1079
|
+
uint32_t dev_id = request.device;
|
|
1080
|
+
if (dev_id >= backends.size()) {
|
|
1081
|
+
return false;
|
|
1082
|
+
}
|
|
1083
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
953
1084
|
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
|
954
1085
|
response.remote_ptr = 0;
|
|
955
1086
|
response.remote_size = 0;
|
|
956
1087
|
if (buffer != nullptr) {
|
|
957
1088
|
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
958
1089
|
response.remote_size = buffer->size;
|
|
959
|
-
LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
|
|
1090
|
+
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
|
|
1091
|
+
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
|
|
960
1092
|
buffers.insert(buffer);
|
|
961
1093
|
} else {
|
|
962
|
-
LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
|
1094
|
+
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
|
|
963
1095
|
}
|
|
1096
|
+
return true;
|
|
964
1097
|
}
|
|
965
1098
|
|
|
966
|
-
|
|
967
|
-
|
|
1099
|
+
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
|
|
1100
|
+
uint32_t dev_id = request.device;
|
|
1101
|
+
if (dev_id >= backends.size()) {
|
|
1102
|
+
return false;
|
|
1103
|
+
}
|
|
1104
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
968
1105
|
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
|
969
|
-
LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
|
|
1106
|
+
LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
|
|
970
1107
|
response.alignment = alignment;
|
|
1108
|
+
return true;
|
|
971
1109
|
}
|
|
972
1110
|
|
|
973
|
-
|
|
974
|
-
|
|
1111
|
+
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
|
|
1112
|
+
uint32_t dev_id = request.device;
|
|
1113
|
+
if (dev_id >= backends.size()) {
|
|
1114
|
+
return false;
|
|
1115
|
+
}
|
|
1116
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
|
975
1117
|
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
|
976
|
-
LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
|
|
1118
|
+
LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
|
|
977
1119
|
response.max_size = max_size;
|
|
1120
|
+
return true;
|
|
978
1121
|
}
|
|
979
1122
|
|
|
980
1123
|
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
|
@@ -1115,7 +1258,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
|
|
1115
1258
|
char hash_str[17];
|
|
1116
1259
|
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
|
1117
1260
|
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
|
1118
|
-
|
|
1261
|
+
std::error_code ec;
|
|
1262
|
+
if (!fs::exists(cache_file, ec)) {
|
|
1119
1263
|
return false;
|
|
1120
1264
|
}
|
|
1121
1265
|
std::ifstream ifs(cache_file, std::ios::binary);
|
|
@@ -1330,25 +1474,35 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|
|
1330
1474
|
return result;
|
|
1331
1475
|
}
|
|
1332
1476
|
|
|
1333
|
-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input
|
|
1477
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
|
1334
1478
|
// serialization format:
|
|
1335
|
-
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1336
|
-
if (input.size() < sizeof(uint32_t)) {
|
|
1479
|
+
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1480
|
+
if (input.size() < 2*sizeof(uint32_t)) {
|
|
1481
|
+
return false;
|
|
1482
|
+
}
|
|
1483
|
+
const uint8_t * src = input.data();
|
|
1484
|
+
uint32_t device;
|
|
1485
|
+
memcpy(&device, src, sizeof(device));
|
|
1486
|
+
src += sizeof(device);
|
|
1487
|
+
if (device >= backends.size()) {
|
|
1337
1488
|
return false;
|
|
1338
1489
|
}
|
|
1339
1490
|
uint32_t n_nodes;
|
|
1340
|
-
memcpy(&n_nodes,
|
|
1341
|
-
|
|
1491
|
+
memcpy(&n_nodes, src, sizeof(n_nodes));
|
|
1492
|
+
src += sizeof(n_nodes);
|
|
1493
|
+
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
|
1342
1494
|
return false;
|
|
1343
1495
|
}
|
|
1344
|
-
const uint64_t * nodes = (const uint64_t *)
|
|
1496
|
+
const uint64_t * nodes = (const uint64_t *)src;
|
|
1497
|
+
src += n_nodes*sizeof(uint64_t);
|
|
1345
1498
|
uint32_t n_tensors;
|
|
1346
|
-
memcpy(&n_tensors,
|
|
1347
|
-
|
|
1499
|
+
memcpy(&n_tensors, src, sizeof(n_tensors));
|
|
1500
|
+
src += sizeof(n_tensors);
|
|
1501
|
+
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
|
1348
1502
|
return false;
|
|
1349
1503
|
}
|
|
1350
|
-
const rpc_tensor * tensors = (const rpc_tensor *)
|
|
1351
|
-
LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
|
1504
|
+
const rpc_tensor * tensors = (const rpc_tensor *)src;
|
|
1505
|
+
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
|
1352
1506
|
|
|
1353
1507
|
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
|
1354
1508
|
|
|
@@ -1363,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1363
1517
|
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
|
1364
1518
|
graph->n_nodes = n_nodes;
|
|
1365
1519
|
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
|
|
1520
|
+
tensor_ptrs.reserve(n_tensors);
|
|
1366
1521
|
for (uint32_t i = 0; i < n_tensors; i++) {
|
|
1367
|
-
tensor_ptrs
|
|
1522
|
+
tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
|
|
1368
1523
|
}
|
|
1369
1524
|
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
|
|
1525
|
+
tensor_map.reserve(n_nodes);
|
|
1370
1526
|
for (uint32_t i = 0; i < n_nodes; i++) {
|
|
1371
1527
|
int64_t id;
|
|
1372
1528
|
memcpy(&id, &nodes[i], sizeof(id));
|
|
@@ -1380,8 +1536,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1380
1536
|
return false;
|
|
1381
1537
|
}
|
|
1382
1538
|
}
|
|
1383
|
-
ggml_status status = ggml_backend_graph_compute(
|
|
1384
|
-
|
|
1539
|
+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
|
1540
|
+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
|
1541
|
+
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
|
1542
|
+
stored_graphs[device].graph = graph;
|
|
1543
|
+
return true;
|
|
1544
|
+
}
|
|
1545
|
+
|
|
1546
|
+
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
|
|
1547
|
+
uint32_t device = request.device;
|
|
1548
|
+
if (device >= backends.size()) {
|
|
1549
|
+
return false;
|
|
1550
|
+
}
|
|
1551
|
+
if (stored_graphs[device].graph == nullptr) {
|
|
1552
|
+
return false;
|
|
1553
|
+
}
|
|
1554
|
+
ggml_cgraph * graph = stored_graphs[device].graph;
|
|
1555
|
+
LOG_DBG("[%s] device: %u\n", __func__, device);
|
|
1556
|
+
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
|
1557
|
+
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
|
1558
|
+
return true;
|
|
1559
|
+
}
|
|
1560
|
+
|
|
1561
|
+
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
|
|
1562
|
+
uint32_t dev_id = request.device;
|
|
1563
|
+
if (dev_id >= backends.size()) {
|
|
1564
|
+
return false;
|
|
1565
|
+
}
|
|
1566
|
+
size_t free, total;
|
|
1567
|
+
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
|
|
1568
|
+
ggml_backend_dev_memory(dev, &free, &total);
|
|
1569
|
+
response.free_mem = free;
|
|
1570
|
+
response.total_mem = total;
|
|
1571
|
+
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
|
|
1385
1572
|
return true;
|
|
1386
1573
|
}
|
|
1387
1574
|
|
|
@@ -1391,9 +1578,9 @@ rpc_server::~rpc_server() {
|
|
|
1391
1578
|
}
|
|
1392
1579
|
}
|
|
1393
1580
|
|
|
1394
|
-
static void rpc_serve_client(ggml_backend_t
|
|
1395
|
-
sockfd_t sockfd
|
|
1396
|
-
rpc_server server(
|
|
1581
|
+
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
|
1582
|
+
sockfd_t sockfd) {
|
|
1583
|
+
rpc_server server(backends, cache_dir);
|
|
1397
1584
|
uint8_t cmd;
|
|
1398
1585
|
if (!recv_data(sockfd, &cmd, 1)) {
|
|
1399
1586
|
return;
|
|
@@ -1425,13 +1612,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1425
1612
|
// HELLO command is handled above
|
|
1426
1613
|
return;
|
|
1427
1614
|
}
|
|
1615
|
+
case RPC_CMD_DEVICE_COUNT: {
|
|
1616
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1617
|
+
return;
|
|
1618
|
+
}
|
|
1619
|
+
rpc_msg_device_count_rsp response;
|
|
1620
|
+
response.device_count = backends.size();
|
|
1621
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1622
|
+
return;
|
|
1623
|
+
}
|
|
1624
|
+
break;
|
|
1625
|
+
}
|
|
1428
1626
|
case RPC_CMD_ALLOC_BUFFER: {
|
|
1429
1627
|
rpc_msg_alloc_buffer_req request;
|
|
1430
1628
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1431
1629
|
return;
|
|
1432
1630
|
}
|
|
1433
1631
|
rpc_msg_alloc_buffer_rsp response;
|
|
1434
|
-
server.alloc_buffer(request, response)
|
|
1632
|
+
if (!server.alloc_buffer(request, response)) {
|
|
1633
|
+
return;
|
|
1634
|
+
}
|
|
1435
1635
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1436
1636
|
return;
|
|
1437
1637
|
}
|
|
@@ -1452,22 +1652,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1452
1652
|
break;
|
|
1453
1653
|
}
|
|
1454
1654
|
case RPC_CMD_GET_ALIGNMENT: {
|
|
1455
|
-
|
|
1655
|
+
rpc_msg_get_alignment_req request;
|
|
1656
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1456
1657
|
return;
|
|
1457
1658
|
}
|
|
1458
1659
|
rpc_msg_get_alignment_rsp response;
|
|
1459
|
-
server.get_alignment(response)
|
|
1660
|
+
if (!server.get_alignment(request, response)) {
|
|
1661
|
+
return;
|
|
1662
|
+
}
|
|
1460
1663
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1461
1664
|
return;
|
|
1462
1665
|
}
|
|
1463
1666
|
break;
|
|
1464
1667
|
}
|
|
1465
1668
|
case RPC_CMD_GET_MAX_SIZE: {
|
|
1466
|
-
|
|
1669
|
+
rpc_msg_get_max_size_req request;
|
|
1670
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1467
1671
|
return;
|
|
1468
1672
|
}
|
|
1469
1673
|
rpc_msg_get_max_size_rsp response;
|
|
1470
|
-
server.get_max_size(response)
|
|
1674
|
+
if (!server.get_max_size(request, response)) {
|
|
1675
|
+
return;
|
|
1676
|
+
}
|
|
1471
1677
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1472
1678
|
return;
|
|
1473
1679
|
}
|
|
@@ -1583,22 +1789,30 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1583
1789
|
if (!recv_msg(sockfd, input)) {
|
|
1584
1790
|
return;
|
|
1585
1791
|
}
|
|
1586
|
-
|
|
1587
|
-
if (!server.graph_compute(input, response)) {
|
|
1792
|
+
if (!server.graph_compute(input)) {
|
|
1588
1793
|
return;
|
|
1589
1794
|
}
|
|
1590
|
-
|
|
1795
|
+
break;
|
|
1796
|
+
}
|
|
1797
|
+
case RPC_CMD_GRAPH_RECOMPUTE: {
|
|
1798
|
+
rpc_msg_graph_recompute_req request;
|
|
1799
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1800
|
+
return;
|
|
1801
|
+
}
|
|
1802
|
+
if (!server.graph_recompute(request)) {
|
|
1591
1803
|
return;
|
|
1592
1804
|
}
|
|
1593
1805
|
break;
|
|
1594
1806
|
}
|
|
1595
1807
|
case RPC_CMD_GET_DEVICE_MEMORY: {
|
|
1596
|
-
|
|
1808
|
+
rpc_msg_get_device_memory_req request;
|
|
1809
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1597
1810
|
return;
|
|
1598
1811
|
}
|
|
1599
1812
|
rpc_msg_get_device_memory_rsp response;
|
|
1600
|
-
|
|
1601
|
-
|
|
1813
|
+
if (!server.get_device_memory(request, response)) {
|
|
1814
|
+
return;
|
|
1815
|
+
}
|
|
1602
1816
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1603
1817
|
return;
|
|
1604
1818
|
}
|
|
@@ -1612,16 +1826,40 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
1612
1826
|
}
|
|
1613
1827
|
}
|
|
1614
1828
|
|
|
1615
|
-
void ggml_backend_rpc_start_server(
|
|
1616
|
-
|
|
1617
|
-
|
|
1829
|
+
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
|
1830
|
+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
|
|
1831
|
+
if (n_devices == 0 || devices == nullptr) {
|
|
1832
|
+
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
|
1833
|
+
return;
|
|
1834
|
+
}
|
|
1835
|
+
std::vector<ggml_backend_t> backends;
|
|
1618
1836
|
printf("Starting RPC server v%d.%d.%d\n",
|
|
1619
1837
|
RPC_PROTO_MAJOR_VERSION,
|
|
1620
1838
|
RPC_PROTO_MINOR_VERSION,
|
|
1621
1839
|
RPC_PROTO_PATCH_VERSION);
|
|
1622
1840
|
printf(" endpoint : %s\n", endpoint);
|
|
1623
1841
|
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
|
1624
|
-
printf("
|
|
1842
|
+
printf("Devices:\n");
|
|
1843
|
+
for (size_t i = 0; i < n_devices; i++) {
|
|
1844
|
+
auto dev = devices[i];
|
|
1845
|
+
size_t free, total;
|
|
1846
|
+
ggml_backend_dev_memory(dev, &free, &total);
|
|
1847
|
+
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
|
1848
|
+
total / 1024 / 1024, free / 1024 / 1024);
|
|
1849
|
+
auto backend = ggml_backend_dev_init(dev, nullptr);
|
|
1850
|
+
if (!backend) {
|
|
1851
|
+
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
|
1852
|
+
return;
|
|
1853
|
+
}
|
|
1854
|
+
backends.push_back(backend);
|
|
1855
|
+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
|
1856
|
+
if (reg) {
|
|
1857
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
|
1858
|
+
if (ggml_backend_set_n_threads_fn) {
|
|
1859
|
+
ggml_backend_set_n_threads_fn(backend, n_threads);
|
|
1860
|
+
}
|
|
1861
|
+
}
|
|
1862
|
+
}
|
|
1625
1863
|
|
|
1626
1864
|
std::string host;
|
|
1627
1865
|
int port;
|
|
@@ -1649,22 +1887,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
|
|
|
1649
1887
|
fprintf(stderr, "Failed to accept client connection\n");
|
|
1650
1888
|
return;
|
|
1651
1889
|
}
|
|
1652
|
-
printf("Accepted client connection
|
|
1890
|
+
printf("Accepted client connection\n");
|
|
1653
1891
|
fflush(stdout);
|
|
1654
|
-
rpc_serve_client(
|
|
1892
|
+
rpc_serve_client(backends, cache_dir, client_socket->fd);
|
|
1655
1893
|
printf("Client connection closed\n");
|
|
1656
1894
|
fflush(stdout);
|
|
1657
1895
|
}
|
|
1658
1896
|
#ifdef _WIN32
|
|
1659
1897
|
WSACleanup();
|
|
1660
1898
|
#endif
|
|
1899
|
+
for (auto backend : backends) {
|
|
1900
|
+
ggml_backend_free(backend);
|
|
1901
|
+
}
|
|
1661
1902
|
}
|
|
1662
1903
|
|
|
1663
1904
|
// device interface
|
|
1664
1905
|
|
|
1665
1906
|
struct ggml_backend_rpc_device_context {
|
|
1666
1907
|
std::string endpoint;
|
|
1908
|
+
uint32_t device;
|
|
1667
1909
|
std::string name;
|
|
1910
|
+
std::string description;
|
|
1668
1911
|
};
|
|
1669
1912
|
|
|
1670
1913
|
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
|
@@ -1676,15 +1919,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
|
|
1676
1919
|
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
|
|
1677
1920
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1678
1921
|
|
|
1679
|
-
return ctx->
|
|
1922
|
+
return ctx->description.c_str();
|
|
1680
1923
|
}
|
|
1681
1924
|
|
|
1682
1925
|
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
1683
1926
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1684
1927
|
|
|
1685
|
-
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
|
|
1686
|
-
|
|
1687
|
-
GGML_UNUSED(dev);
|
|
1928
|
+
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
|
|
1688
1929
|
}
|
|
1689
1930
|
|
|
1690
1931
|
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
|
|
@@ -1710,7 +1951,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
|
|
|
1710
1951
|
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
1711
1952
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1712
1953
|
|
|
1713
|
-
return ggml_backend_rpc_init(ctx->endpoint.c_str());
|
|
1954
|
+
return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
|
|
1714
1955
|
|
|
1715
1956
|
GGML_UNUSED(params);
|
|
1716
1957
|
}
|
|
@@ -1718,7 +1959,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
|
|
|
1718
1959
|
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
1719
1960
|
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1720
1961
|
|
|
1721
|
-
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
1962
|
+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
|
|
1722
1963
|
|
|
1723
1964
|
GGML_UNUSED(dev);
|
|
1724
1965
|
}
|
|
@@ -1736,7 +1977,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
|
|
|
1736
1977
|
}
|
|
1737
1978
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
1738
1979
|
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1739
|
-
return buft_ctx->endpoint == dev_ctx->endpoint;
|
|
1980
|
+
return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
|
|
1740
1981
|
}
|
|
1741
1982
|
|
|
1742
1983
|
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
|
@@ -1759,28 +2000,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
|
|
1759
2000
|
|
|
1760
2001
|
// backend reg interface
|
|
1761
2002
|
|
|
1762
|
-
|
|
1763
|
-
|
|
2003
|
+
struct ggml_backend_rpc_reg_context {
|
|
2004
|
+
std::string name;
|
|
2005
|
+
std::vector<ggml_backend_dev_t> devices;
|
|
2006
|
+
};
|
|
1764
2007
|
|
|
1765
|
-
|
|
2008
|
+
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
|
|
2009
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2010
|
+
return ctx ? ctx->name.c_str() : "RPC";
|
|
1766
2011
|
}
|
|
1767
2012
|
|
|
1768
2013
|
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
|
|
1769
|
-
|
|
1770
|
-
|
|
1771
|
-
GGML_UNUSED(reg);
|
|
2014
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2015
|
+
return ctx ? ctx->devices.size() : 0;
|
|
1772
2016
|
}
|
|
1773
2017
|
|
|
1774
2018
|
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
|
1775
|
-
|
|
1776
|
-
|
|
1777
|
-
|
|
1778
|
-
|
|
2019
|
+
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
|
2020
|
+
if (ctx == nullptr) {
|
|
2021
|
+
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
|
|
2022
|
+
} else {
|
|
2023
|
+
GGML_ASSERT(index < ctx->devices.size());
|
|
2024
|
+
return ctx->devices[index];
|
|
2025
|
+
}
|
|
1779
2026
|
}
|
|
1780
2027
|
|
|
1781
2028
|
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
|
1782
|
-
if (std::strcmp(name, "
|
|
1783
|
-
return (void *)
|
|
2029
|
+
if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
|
|
2030
|
+
return (void *)ggml_backend_rpc_add_server;
|
|
1784
2031
|
}
|
|
1785
2032
|
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
|
|
1786
2033
|
return (void *)ggml_backend_rpc_start_server;
|
|
@@ -1807,30 +2054,65 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
|
|
|
1807
2054
|
return &ggml_backend_rpc_reg;
|
|
1808
2055
|
}
|
|
1809
2056
|
|
|
1810
|
-
|
|
1811
|
-
|
|
2057
|
+
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
|
|
2058
|
+
auto sock = get_socket(endpoint);
|
|
2059
|
+
if (sock == nullptr) {
|
|
2060
|
+
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
|
|
2061
|
+
return 0;
|
|
2062
|
+
}
|
|
2063
|
+
rpc_msg_device_count_rsp response;
|
|
2064
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
|
|
2065
|
+
RPC_STATUS_ASSERT(status);
|
|
2066
|
+
return response.device_count;
|
|
2067
|
+
}
|
|
2068
|
+
|
|
2069
|
+
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
|
|
2070
|
+
/* .get_name = */ ggml_backend_rpc_reg_get_name,
|
|
2071
|
+
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
|
|
2072
|
+
/* .get_device = */ ggml_backend_rpc_reg_get_device,
|
|
2073
|
+
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
|
|
2074
|
+
};
|
|
1812
2075
|
|
|
2076
|
+
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
|
|
2077
|
+
static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
|
|
1813
2078
|
static std::mutex mutex;
|
|
2079
|
+
static uint32_t dev_id = 0;
|
|
1814
2080
|
std::lock_guard<std::mutex> lock(mutex);
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
return dev_map[endpoint];
|
|
2081
|
+
if (reg_map.find(endpoint) != reg_map.end()) {
|
|
2082
|
+
return reg_map[endpoint];
|
|
1818
2083
|
}
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
2084
|
+
uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
|
|
2085
|
+
if (dev_count == 0) {
|
|
2086
|
+
return nullptr;
|
|
2087
|
+
}
|
|
2088
|
+
ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
|
|
2089
|
+
ctx->name = "RPC[" + std::string(endpoint) + "]";
|
|
2090
|
+
for (uint32_t ind = 0; ind < dev_count; ind++) {
|
|
2091
|
+
std::string dev_name = "RPC" + std::to_string(dev_id);
|
|
2092
|
+
std::string dev_desc = std::string(endpoint);
|
|
2093
|
+
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
|
|
2094
|
+
/* .endpoint = */ endpoint,
|
|
2095
|
+
/* .device = */ ind,
|
|
2096
|
+
/* .name = */ dev_name,
|
|
2097
|
+
/* .description = */ dev_desc
|
|
2098
|
+
};
|
|
2099
|
+
|
|
2100
|
+
ggml_backend_dev_t dev = new ggml_backend_device {
|
|
2101
|
+
/* .iface = */ ggml_backend_rpc_device_i,
|
|
2102
|
+
/* .reg = */ ggml_backend_rpc_reg(),
|
|
2103
|
+
/* .context = */ dev_ctx,
|
|
2104
|
+
};
|
|
2105
|
+
ctx->devices.push_back(dev);
|
|
2106
|
+
dev_id++;
|
|
2107
|
+
}
|
|
2108
|
+
ggml_backend_reg_t reg = new ggml_backend_reg {
|
|
2109
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
2110
|
+
/* .iface = */ ggml_backend_rpc_reg_interface,
|
|
2111
|
+
/* .context = */ ctx
|
|
1829
2112
|
};
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
return dev;
|
|
2113
|
+
reg_map[endpoint] = reg;
|
|
2114
|
+
return reg;
|
|
1834
2115
|
}
|
|
1835
2116
|
|
|
2117
|
+
|
|
1836
2118
|
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
|