@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -19,12 +19,6 @@
|
|
|
19
19
|
#define GROUP_MAX_EPS_IQ1_M 1e-7f
|
|
20
20
|
#define GROUP_MAX_EPS_IQ1_S 1e-12f
|
|
21
21
|
|
|
22
|
-
#if defined(_MSC_VER)
|
|
23
|
-
// disable "possible loss of data" to avoid warnings for hundreds of casts
|
|
24
|
-
// we should just be careful :)
|
|
25
|
-
#pragma warning(disable: 4244 4267)
|
|
26
|
-
#endif
|
|
27
|
-
|
|
28
22
|
#define UNUSED GGML_UNUSED
|
|
29
23
|
|
|
30
24
|
// reference implementation for deterministic creation of model files
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
#include "ggml-rpc.h"
|
|
2
2
|
#include "ggml-impl.h"
|
|
3
3
|
#include "ggml-backend-impl.h"
|
|
4
|
+
#include "ggml-cpp.h"
|
|
4
5
|
|
|
5
6
|
#include <cinttypes>
|
|
6
7
|
#include <string>
|
|
@@ -26,6 +27,10 @@
|
|
|
26
27
|
# include <unistd.h>
|
|
27
28
|
#endif
|
|
28
29
|
#include <cstring>
|
|
30
|
+
#include <fstream>
|
|
31
|
+
#include <filesystem>
|
|
32
|
+
|
|
33
|
+
namespace fs = std::filesystem;
|
|
29
34
|
|
|
30
35
|
#ifdef _WIN32
|
|
31
36
|
typedef SOCKET sockfd_t;
|
|
@@ -80,15 +85,26 @@ enum rpc_cmd {
|
|
|
80
85
|
RPC_CMD_FREE_BUFFER,
|
|
81
86
|
RPC_CMD_BUFFER_CLEAR,
|
|
82
87
|
RPC_CMD_SET_TENSOR,
|
|
88
|
+
RPC_CMD_SET_TENSOR_HASH,
|
|
83
89
|
RPC_CMD_GET_TENSOR,
|
|
84
90
|
RPC_CMD_COPY_TENSOR,
|
|
85
91
|
RPC_CMD_GRAPH_COMPUTE,
|
|
86
92
|
RPC_CMD_GET_DEVICE_MEMORY,
|
|
87
93
|
RPC_CMD_INIT_TENSOR,
|
|
88
94
|
RPC_CMD_GET_ALLOC_SIZE,
|
|
95
|
+
RPC_CMD_HELLO,
|
|
89
96
|
RPC_CMD_COUNT,
|
|
90
97
|
};
|
|
91
98
|
|
|
99
|
+
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
|
100
|
+
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
|
101
|
+
|
|
102
|
+
struct rpc_msg_hello_rsp {
|
|
103
|
+
uint8_t major;
|
|
104
|
+
uint8_t minor;
|
|
105
|
+
uint8_t patch;
|
|
106
|
+
};
|
|
107
|
+
|
|
92
108
|
struct rpc_msg_get_alloc_size_req {
|
|
93
109
|
rpc_tensor tensor;
|
|
94
110
|
};
|
|
@@ -135,6 +151,16 @@ struct rpc_msg_buffer_clear_req {
|
|
|
135
151
|
uint8_t value;
|
|
136
152
|
};
|
|
137
153
|
|
|
154
|
+
struct rpc_msg_set_tensor_hash_req {
|
|
155
|
+
rpc_tensor tensor;
|
|
156
|
+
uint64_t offset;
|
|
157
|
+
uint64_t hash;
|
|
158
|
+
};
|
|
159
|
+
|
|
160
|
+
struct rpc_msg_set_tensor_hash_rsp {
|
|
161
|
+
uint8_t result;
|
|
162
|
+
};
|
|
163
|
+
|
|
138
164
|
struct rpc_msg_get_tensor_req {
|
|
139
165
|
rpc_tensor tensor;
|
|
140
166
|
uint64_t offset;
|
|
@@ -187,6 +213,18 @@ struct ggml_backend_rpc_buffer_context {
|
|
|
187
213
|
|
|
188
214
|
// RPC helper functions
|
|
189
215
|
|
|
216
|
+
// Computes FNV-1a hash of the data
|
|
217
|
+
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
|
|
218
|
+
const uint64_t fnv_prime = 0x100000001b3ULL;
|
|
219
|
+
uint64_t hash = 0xcbf29ce484222325ULL;
|
|
220
|
+
|
|
221
|
+
for (size_t i = 0; i < len; ++i) {
|
|
222
|
+
hash ^= data[i];
|
|
223
|
+
hash *= fnv_prime;
|
|
224
|
+
}
|
|
225
|
+
return hash;
|
|
226
|
+
}
|
|
227
|
+
|
|
190
228
|
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
|
|
191
229
|
#ifdef _WIN32
|
|
192
230
|
if (fd == INVALID_SOCKET) {
|
|
@@ -346,8 +384,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
|
|
|
346
384
|
}
|
|
347
385
|
|
|
348
386
|
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
|
349
|
-
//
|
|
350
|
-
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size
|
|
387
|
+
// No response
|
|
388
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
|
|
351
389
|
uint8_t cmd_byte = cmd;
|
|
352
390
|
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
|
353
391
|
return false;
|
|
@@ -358,6 +396,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
|
358
396
|
if (!send_data(sock->fd, input, input_size)) {
|
|
359
397
|
return false;
|
|
360
398
|
}
|
|
399
|
+
return true;
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
|
403
|
+
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
|
404
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
|
405
|
+
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
|
|
406
|
+
return false;
|
|
407
|
+
}
|
|
361
408
|
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
|
362
409
|
// even if we do, we can skip sending output_size from the server for commands with known output size
|
|
363
410
|
uint64_t out_size;
|
|
@@ -375,6 +422,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
|
375
422
|
|
|
376
423
|
// RPC client-side implementation
|
|
377
424
|
|
|
425
|
+
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
|
426
|
+
rpc_msg_hello_rsp response;
|
|
427
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
|
428
|
+
GGML_ASSERT(status);
|
|
429
|
+
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
|
430
|
+
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
431
|
+
return false;
|
|
432
|
+
}
|
|
433
|
+
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
|
434
|
+
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
435
|
+
}
|
|
436
|
+
return true;
|
|
437
|
+
}
|
|
438
|
+
|
|
378
439
|
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
379
440
|
static std::mutex mutex;
|
|
380
441
|
std::lock_guard<std::mutex> lock(mutex);
|
|
@@ -408,6 +469,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
408
469
|
if (sock == nullptr) {
|
|
409
470
|
return nullptr;
|
|
410
471
|
}
|
|
472
|
+
if (!check_server_version(sock)) {
|
|
473
|
+
return nullptr;
|
|
474
|
+
}
|
|
411
475
|
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
|
412
476
|
sockets[endpoint] = sock;
|
|
413
477
|
return sock;
|
|
@@ -460,6 +524,11 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
|
460
524
|
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
|
|
461
525
|
result.view_offs = tensor->view_offs;
|
|
462
526
|
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
|
527
|
+
|
|
528
|
+
// Avoid sending uninitialized data over the wire
|
|
529
|
+
memset(result.name, 0, sizeof(result.name));
|
|
530
|
+
memset(result.padding, 0, sizeof(result.padding));
|
|
531
|
+
|
|
463
532
|
snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
|
|
464
533
|
return result;
|
|
465
534
|
}
|
|
@@ -483,14 +552,27 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
|
|
|
483
552
|
|
|
484
553
|
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
485
554
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
486
|
-
|
|
555
|
+
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
|
556
|
+
if (size > HASH_THRESHOLD) {
|
|
557
|
+
rpc_msg_set_tensor_hash_req request;
|
|
558
|
+
request.tensor = rpc_tensor;
|
|
559
|
+
request.offset = offset;
|
|
560
|
+
request.hash = fnv_hash((const uint8_t*)data, size);
|
|
561
|
+
rpc_msg_set_tensor_hash_rsp response;
|
|
562
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
|
|
563
|
+
GGML_ASSERT(status);
|
|
564
|
+
if (response.result) {
|
|
565
|
+
// the server has the same data, no need to send it
|
|
566
|
+
return;
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
|
|
487
570
|
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
|
|
488
571
|
std::vector<uint8_t> input(input_size, 0);
|
|
489
|
-
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
|
490
572
|
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
|
491
573
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
|
492
574
|
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
|
493
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()
|
|
575
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
|
|
494
576
|
GGML_ASSERT(status);
|
|
495
577
|
}
|
|
496
578
|
|
|
@@ -772,9 +854,12 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
|
|
|
772
854
|
|
|
773
855
|
class rpc_server {
|
|
774
856
|
public:
|
|
775
|
-
rpc_server(ggml_backend_t backend
|
|
857
|
+
rpc_server(ggml_backend_t backend, const char * cache_dir)
|
|
858
|
+
: backend(backend), cache_dir(cache_dir) {
|
|
859
|
+
}
|
|
776
860
|
~rpc_server();
|
|
777
861
|
|
|
862
|
+
void hello(rpc_msg_hello_rsp & response);
|
|
778
863
|
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
779
864
|
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
|
780
865
|
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
@@ -782,6 +867,7 @@ public:
|
|
|
782
867
|
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
|
783
868
|
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
|
784
869
|
bool set_tensor(const std::vector<uint8_t> & input);
|
|
870
|
+
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
|
785
871
|
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
|
786
872
|
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
|
787
873
|
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
|
@@ -789,6 +875,7 @@ public:
|
|
|
789
875
|
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
|
790
876
|
|
|
791
877
|
private:
|
|
878
|
+
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
|
792
879
|
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
|
793
880
|
ggml_tensor * create_node(uint64_t id,
|
|
794
881
|
struct ggml_context * ctx,
|
|
@@ -797,9 +884,17 @@ private:
|
|
|
797
884
|
|
|
798
885
|
|
|
799
886
|
ggml_backend_t backend;
|
|
887
|
+
const char * cache_dir;
|
|
800
888
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
801
889
|
};
|
|
802
890
|
|
|
891
|
+
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
|
892
|
+
response.major = RPC_PROTO_MAJOR_VERSION;
|
|
893
|
+
response.minor = RPC_PROTO_MINOR_VERSION;
|
|
894
|
+
response.patch = RPC_PROTO_PATCH_VERSION;
|
|
895
|
+
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
|
896
|
+
}
|
|
897
|
+
|
|
803
898
|
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
|
804
899
|
ggml_backend_buffer_type_t buft;
|
|
805
900
|
struct ggml_init_params params {
|
|
@@ -808,12 +903,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
|
|
808
903
|
/*.no_alloc =*/ true,
|
|
809
904
|
};
|
|
810
905
|
|
|
811
|
-
|
|
906
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
907
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
908
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
812
909
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
813
910
|
|
|
814
911
|
if (tensor == nullptr) {
|
|
815
912
|
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
|
816
|
-
ggml_free(ctx);
|
|
817
913
|
return false;
|
|
818
914
|
}
|
|
819
915
|
|
|
@@ -826,7 +922,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
|
|
826
922
|
|
|
827
923
|
response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
|
|
828
924
|
|
|
829
|
-
ggml_free(ctx);
|
|
830
925
|
return true;
|
|
831
926
|
}
|
|
832
927
|
|
|
@@ -895,8 +990,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
|
|
895
990
|
}
|
|
896
991
|
|
|
897
992
|
ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
|
|
993
|
+
// Validate tensor type before using it
|
|
994
|
+
if (tensor->type >= GGML_TYPE_COUNT) {
|
|
995
|
+
GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
|
|
996
|
+
return nullptr;
|
|
997
|
+
}
|
|
998
|
+
|
|
898
999
|
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
|
|
899
1000
|
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
1001
|
+
|
|
1002
|
+
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
|
|
1003
|
+
if (result == nullptr) {
|
|
1004
|
+
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
|
|
1005
|
+
return nullptr;
|
|
1006
|
+
}
|
|
1007
|
+
|
|
900
1008
|
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
|
901
1009
|
result->nb[i] = tensor->nb[i];
|
|
902
1010
|
}
|
|
@@ -940,11 +1048,12 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
940
1048
|
/*.mem_buffer =*/ NULL,
|
|
941
1049
|
/*.no_alloc =*/ true,
|
|
942
1050
|
};
|
|
943
|
-
|
|
1051
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1052
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1053
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
944
1054
|
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
|
945
1055
|
if (tensor == nullptr) {
|
|
946
1056
|
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
947
|
-
ggml_free(ctx);
|
|
948
1057
|
return false;
|
|
949
1058
|
}
|
|
950
1059
|
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
|
@@ -955,13 +1064,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
955
1064
|
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
|
956
1065
|
|
|
957
1066
|
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
|
958
|
-
|
|
1067
|
+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
|
|
1068
|
+
__func__, in_tensor->data, offset, size, p0, p1);
|
|
1069
|
+
return false;
|
|
959
1070
|
}
|
|
960
1071
|
}
|
|
961
1072
|
|
|
962
1073
|
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
|
|
1074
|
+
if (cache_dir && size > HASH_THRESHOLD) {
|
|
1075
|
+
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
|
1076
|
+
char hash_str[17];
|
|
1077
|
+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
|
1078
|
+
// save to cache_dir/hash_str
|
|
1079
|
+
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
|
1080
|
+
std::ofstream ofs(cache_file, std::ios::binary);
|
|
1081
|
+
ofs.write((const char *)data, size);
|
|
1082
|
+
printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
|
1083
|
+
}
|
|
963
1084
|
ggml_backend_tensor_set(tensor, data, offset, size);
|
|
964
|
-
|
|
1085
|
+
return true;
|
|
1086
|
+
}
|
|
1087
|
+
|
|
1088
|
+
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
|
1089
|
+
if (!cache_dir) {
|
|
1090
|
+
return false;
|
|
1091
|
+
}
|
|
1092
|
+
char hash_str[17];
|
|
1093
|
+
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
|
1094
|
+
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
|
1095
|
+
if (!fs::exists(cache_file)) {
|
|
1096
|
+
return false;
|
|
1097
|
+
}
|
|
1098
|
+
std::ifstream ifs(cache_file, std::ios::binary);
|
|
1099
|
+
ifs.seekg(0, std::ios::end);
|
|
1100
|
+
size_t size = ifs.tellg();
|
|
1101
|
+
ifs.seekg(0, std::ios::beg);
|
|
1102
|
+
data.resize(size);
|
|
1103
|
+
ifs.read((char *)data.data(), size);
|
|
1104
|
+
return true;
|
|
1105
|
+
}
|
|
1106
|
+
|
|
1107
|
+
bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
|
|
1108
|
+
{
|
|
1109
|
+
std::vector<uint8_t> cached_file;
|
|
1110
|
+
if (!get_cached_file(request.hash, cached_file)) {
|
|
1111
|
+
response.result = 0;
|
|
1112
|
+
return true;
|
|
1113
|
+
}
|
|
1114
|
+
size_t size = cached_file.size();
|
|
1115
|
+
struct ggml_init_params params {
|
|
1116
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
1117
|
+
/*.mem_buffer =*/ NULL,
|
|
1118
|
+
/*.no_alloc =*/ true,
|
|
1119
|
+
};
|
|
1120
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1121
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1122
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
1123
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
1124
|
+
if (tensor == nullptr) {
|
|
1125
|
+
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
1126
|
+
return false;
|
|
1127
|
+
}
|
|
1128
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
|
|
1129
|
+
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
|
|
1130
|
+
|
|
1131
|
+
// sanitize tensor->data
|
|
1132
|
+
{
|
|
1133
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
|
1134
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
|
1135
|
+
|
|
1136
|
+
if (request.tensor.data + request.offset < p0
|
|
1137
|
+
|| request.tensor.data + request.offset >= p1
|
|
1138
|
+
|| size > (p1 - request.tensor.data - request.offset)) {
|
|
1139
|
+
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
|
1140
|
+
__func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
|
|
1141
|
+
return false;
|
|
1142
|
+
}
|
|
1143
|
+
}
|
|
1144
|
+
ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
|
|
1145
|
+
response.result = 1;
|
|
965
1146
|
return true;
|
|
966
1147
|
}
|
|
967
1148
|
|
|
@@ -971,11 +1152,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
|
|
971
1152
|
/*.mem_buffer =*/ NULL,
|
|
972
1153
|
/*.no_alloc =*/ true,
|
|
973
1154
|
};
|
|
974
|
-
|
|
1155
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1156
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1157
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
975
1158
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
976
1159
|
if (tensor == nullptr) {
|
|
977
1160
|
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
|
|
978
|
-
ggml_free(ctx);
|
|
979
1161
|
return false;
|
|
980
1162
|
}
|
|
981
1163
|
|
|
@@ -991,11 +1173,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
|
|
991
1173
|
// This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
|
|
992
1174
|
// Currently unimplemented.
|
|
993
1175
|
GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
|
|
994
|
-
ggml_free(ctx);
|
|
995
1176
|
return false;
|
|
996
1177
|
}
|
|
997
1178
|
|
|
998
|
-
ggml_free(ctx);
|
|
999
1179
|
return true;
|
|
1000
1180
|
}
|
|
1001
1181
|
|
|
@@ -1005,11 +1185,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
|
|
|
1005
1185
|
/*.mem_buffer =*/ NULL,
|
|
1006
1186
|
/*.no_alloc =*/ true,
|
|
1007
1187
|
};
|
|
1008
|
-
|
|
1188
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1189
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1190
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
1009
1191
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
1010
1192
|
if (tensor == nullptr) {
|
|
1011
1193
|
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
|
1012
|
-
ggml_free(ctx);
|
|
1013
1194
|
return false;
|
|
1014
1195
|
}
|
|
1015
1196
|
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
|
@@ -1022,13 +1203,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
|
|
|
1022
1203
|
if (request.tensor.data + request.offset < p0 ||
|
|
1023
1204
|
request.tensor.data + request.offset >= p1 ||
|
|
1024
1205
|
request.size > (p1 - request.tensor.data - request.offset)) {
|
|
1025
|
-
|
|
1206
|
+
GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
|
1207
|
+
__func__, request.tensor.data, request.offset, request.size, p0, p1);
|
|
1208
|
+
return false;
|
|
1026
1209
|
}
|
|
1027
1210
|
}
|
|
1028
1211
|
|
|
1029
1212
|
response.resize(request.size, 0);
|
|
1030
1213
|
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
|
1031
|
-
ggml_free(ctx);
|
|
1032
1214
|
return true;
|
|
1033
1215
|
}
|
|
1034
1216
|
|
|
@@ -1038,12 +1220,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1038
1220
|
/*.mem_buffer =*/ NULL,
|
|
1039
1221
|
/*.no_alloc =*/ true,
|
|
1040
1222
|
};
|
|
1041
|
-
|
|
1223
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1224
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1225
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
1226
|
+
|
|
1042
1227
|
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
|
1043
1228
|
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
|
1044
1229
|
if (src == nullptr || dst == nullptr) {
|
|
1045
1230
|
GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
|
|
1046
|
-
ggml_free(ctx);
|
|
1047
1231
|
return false;
|
|
1048
1232
|
}
|
|
1049
1233
|
|
|
@@ -1061,7 +1245,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1061
1245
|
dst_data + src_size,
|
|
1062
1246
|
dst_base,
|
|
1063
1247
|
dst_base + dst_buf_sz);
|
|
1064
|
-
ggml_free(ctx);
|
|
1065
1248
|
return false;
|
|
1066
1249
|
}
|
|
1067
1250
|
|
|
@@ -1069,7 +1252,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
|
|
|
1069
1252
|
__func__, (void*) src->buffer, (void*) dst->buffer);
|
|
1070
1253
|
|
|
1071
1254
|
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
|
1072
|
-
ggml_free(ctx);
|
|
1073
1255
|
return true;
|
|
1074
1256
|
}
|
|
1075
1257
|
|
|
@@ -1077,22 +1259,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|
|
1077
1259
|
struct ggml_context * ctx,
|
|
1078
1260
|
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
|
1079
1261
|
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
|
|
1080
|
-
if (id == 0) {
|
|
1081
|
-
return nullptr;
|
|
1082
|
-
}
|
|
1083
1262
|
if (tensor_map.find(id) != tensor_map.end()) {
|
|
1084
1263
|
return tensor_map[id];
|
|
1085
1264
|
}
|
|
1086
|
-
|
|
1265
|
+
// Safely find the tensor pointer
|
|
1266
|
+
auto it_ptr = tensor_ptrs.find(id);
|
|
1267
|
+
if (it_ptr == tensor_ptrs.end()) {
|
|
1268
|
+
return nullptr;
|
|
1269
|
+
}
|
|
1270
|
+
const rpc_tensor * tensor = it_ptr->second;
|
|
1271
|
+
|
|
1087
1272
|
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
|
|
1088
1273
|
if (result == nullptr) {
|
|
1089
1274
|
return nullptr;
|
|
1090
1275
|
}
|
|
1091
1276
|
tensor_map[id] = result;
|
|
1092
1277
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
1093
|
-
|
|
1278
|
+
// Check if the source ID is 0 before calling create_node recursively
|
|
1279
|
+
if (tensor->src[i] == 0) {
|
|
1280
|
+
result->src[i] = nullptr;
|
|
1281
|
+
} else {
|
|
1282
|
+
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
|
|
1283
|
+
// If the recursive call failed for a non-zero ID, propagate the error
|
|
1284
|
+
if (result->src[i] == nullptr) {
|
|
1285
|
+
GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
|
|
1286
|
+
__func__, i, tensor->src[i], id);
|
|
1287
|
+
// Must return nullptr to signal failure up the call stack
|
|
1288
|
+
return nullptr;
|
|
1289
|
+
}
|
|
1290
|
+
}
|
|
1291
|
+
}
|
|
1292
|
+
|
|
1293
|
+
// Handle view_src similarly
|
|
1294
|
+
if (tensor->view_src == 0) {
|
|
1295
|
+
result->view_src = nullptr;
|
|
1296
|
+
} else {
|
|
1297
|
+
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
|
|
1298
|
+
// If the recursive call failed for a non-zero ID, propagate the error
|
|
1299
|
+
if (result->view_src == nullptr) {
|
|
1300
|
+
GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
|
|
1301
|
+
__func__, tensor->view_src, id);
|
|
1302
|
+
// Must return nullptr to signal failure up the call stack
|
|
1303
|
+
return nullptr;
|
|
1304
|
+
}
|
|
1094
1305
|
}
|
|
1095
|
-
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
|
|
1096
1306
|
result->view_offs = tensor->view_offs;
|
|
1097
1307
|
return result;
|
|
1098
1308
|
}
|
|
@@ -1118,12 +1328,15 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1118
1328
|
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
|
1119
1329
|
|
|
1120
1330
|
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
|
1331
|
+
|
|
1121
1332
|
struct ggml_init_params params = {
|
|
1122
1333
|
/*.mem_size =*/ buf_size,
|
|
1123
1334
|
/*.mem_buffer =*/ NULL,
|
|
1124
1335
|
/*.no_alloc =*/ true,
|
|
1125
1336
|
};
|
|
1126
|
-
|
|
1337
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
1338
|
+
GGML_ASSERT(ctx_ptr != nullptr);
|
|
1339
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
1127
1340
|
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
|
1128
1341
|
graph->n_nodes = n_nodes;
|
|
1129
1342
|
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
|
|
@@ -1135,10 +1348,17 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|
|
1135
1348
|
int64_t id;
|
|
1136
1349
|
memcpy(&id, &nodes[i], sizeof(id));
|
|
1137
1350
|
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
|
1351
|
+
|
|
1352
|
+
// Check if create_node failed for a *non-zero* ID.
|
|
1353
|
+
// If id was 0, create_node returning nullptr is expected.
|
|
1354
|
+
// If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
|
|
1355
|
+
if (graph->nodes[i] == nullptr && id != 0) {
|
|
1356
|
+
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
|
|
1357
|
+
return false;
|
|
1358
|
+
}
|
|
1138
1359
|
}
|
|
1139
1360
|
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
|
1140
1361
|
response.result = status;
|
|
1141
|
-
ggml_free(ctx);
|
|
1142
1362
|
return true;
|
|
1143
1363
|
}
|
|
1144
1364
|
|
|
@@ -1148,10 +1368,27 @@ rpc_server::~rpc_server() {
|
|
|
1148
1368
|
}
|
|
1149
1369
|
}
|
|
1150
1370
|
|
|
1151
|
-
static void rpc_serve_client(ggml_backend_t backend,
|
|
1152
|
-
|
|
1371
|
+
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
1372
|
+
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
|
1373
|
+
rpc_server server(backend, cache_dir);
|
|
1374
|
+
uint8_t cmd;
|
|
1375
|
+
if (!recv_data(sockfd, &cmd, 1)) {
|
|
1376
|
+
return;
|
|
1377
|
+
}
|
|
1378
|
+
// the first command sent by the client must be HELLO
|
|
1379
|
+
if (cmd != RPC_CMD_HELLO) {
|
|
1380
|
+
fprintf(stderr, "Expected HELLO command, update client\n");
|
|
1381
|
+
return;
|
|
1382
|
+
}
|
|
1383
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1384
|
+
return;
|
|
1385
|
+
}
|
|
1386
|
+
rpc_msg_hello_rsp response;
|
|
1387
|
+
server.hello(response);
|
|
1388
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1389
|
+
return;
|
|
1390
|
+
}
|
|
1153
1391
|
while (true) {
|
|
1154
|
-
uint8_t cmd;
|
|
1155
1392
|
if (!recv_data(sockfd, &cmd, 1)) {
|
|
1156
1393
|
break;
|
|
1157
1394
|
}
|
|
@@ -1161,6 +1398,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1161
1398
|
break;
|
|
1162
1399
|
}
|
|
1163
1400
|
switch (cmd) {
|
|
1401
|
+
case RPC_CMD_HELLO: {
|
|
1402
|
+
// HELLO command is handled above
|
|
1403
|
+
return;
|
|
1404
|
+
}
|
|
1164
1405
|
case RPC_CMD_ALLOC_BUFFER: {
|
|
1165
1406
|
rpc_msg_alloc_buffer_req request;
|
|
1166
1407
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
@@ -1179,7 +1420,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1179
1420
|
return;
|
|
1180
1421
|
}
|
|
1181
1422
|
rpc_msg_get_alloc_size_rsp response;
|
|
1182
|
-
server.get_alloc_size(request, response)
|
|
1423
|
+
if (!server.get_alloc_size(request, response)) {
|
|
1424
|
+
return;
|
|
1425
|
+
}
|
|
1183
1426
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1184
1427
|
return;
|
|
1185
1428
|
}
|
|
@@ -1255,7 +1498,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1255
1498
|
if (!server.set_tensor(input)) {
|
|
1256
1499
|
return;
|
|
1257
1500
|
}
|
|
1258
|
-
|
|
1501
|
+
break;
|
|
1502
|
+
}
|
|
1503
|
+
case RPC_CMD_SET_TENSOR_HASH: {
|
|
1504
|
+
rpc_msg_set_tensor_hash_req request;
|
|
1505
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1506
|
+
return;
|
|
1507
|
+
}
|
|
1508
|
+
rpc_msg_set_tensor_hash_rsp response;
|
|
1509
|
+
if (!server.set_tensor_hash(request, response)) {
|
|
1510
|
+
return;
|
|
1511
|
+
}
|
|
1512
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1259
1513
|
return;
|
|
1260
1514
|
}
|
|
1261
1515
|
break;
|
|
@@ -1335,7 +1589,17 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1335
1589
|
}
|
|
1336
1590
|
}
|
|
1337
1591
|
|
|
1338
|
-
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
|
1592
|
+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
|
1593
|
+
const char * cache_dir,
|
|
1594
|
+
size_t free_mem, size_t total_mem) {
|
|
1595
|
+
printf("Starting RPC server v%d.%d.%d\n",
|
|
1596
|
+
RPC_PROTO_MAJOR_VERSION,
|
|
1597
|
+
RPC_PROTO_MINOR_VERSION,
|
|
1598
|
+
RPC_PROTO_PATCH_VERSION);
|
|
1599
|
+
printf(" endpoint : %s\n", endpoint);
|
|
1600
|
+
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
|
1601
|
+
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
|
|
1602
|
+
|
|
1339
1603
|
std::string host;
|
|
1340
1604
|
int port;
|
|
1341
1605
|
if (!parse_endpoint(endpoint, host, port)) {
|
|
@@ -1364,7 +1628,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
|
|
|
1364
1628
|
}
|
|
1365
1629
|
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
|
1366
1630
|
fflush(stdout);
|
|
1367
|
-
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
|
1631
|
+
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
|
|
1368
1632
|
printf("Client connection closed\n");
|
|
1369
1633
|
fflush(stdout);
|
|
1370
1634
|
}
|
|
@@ -1495,6 +1759,9 @@ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch
|
|
|
1495
1759
|
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
|
|
1496
1760
|
return (void *)ggml_backend_rpc_add_device;
|
|
1497
1761
|
}
|
|
1762
|
+
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
|
|
1763
|
+
return (void *)ggml_backend_rpc_start_server;
|
|
1764
|
+
}
|
|
1498
1765
|
return NULL;
|
|
1499
1766
|
|
|
1500
1767
|
GGML_UNUSED(reg);
|