@fugood/llama.node 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -98,8 +98,8 @@ struct ring_buffer {
|
|
|
98
98
|
std::vector<T> data;
|
|
99
99
|
};
|
|
100
100
|
|
|
101
|
-
struct
|
|
102
|
-
|
|
101
|
+
struct common_sampler {
|
|
102
|
+
common_params_sampling params;
|
|
103
103
|
|
|
104
104
|
struct llama_sampler * grmr;
|
|
105
105
|
struct llama_sampler * chain;
|
|
@@ -125,26 +125,28 @@ struct gpt_sampler {
|
|
|
125
125
|
}
|
|
126
126
|
};
|
|
127
127
|
|
|
128
|
-
std::string
|
|
128
|
+
std::string common_params_sampling::print() const {
|
|
129
129
|
char result[1024];
|
|
130
130
|
|
|
131
131
|
snprintf(result, sizeof(result),
|
|
132
132
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
133
|
-
"\
|
|
133
|
+
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
|
134
|
+
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
|
134
135
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
135
136
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
136
|
-
|
|
137
|
+
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
|
138
|
+
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
|
137
139
|
mirostat, mirostat_eta, mirostat_tau);
|
|
138
140
|
|
|
139
141
|
return std::string(result);
|
|
140
142
|
}
|
|
141
143
|
|
|
142
|
-
struct
|
|
144
|
+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
|
143
145
|
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
|
144
146
|
|
|
145
147
|
lparams.no_perf = params.no_perf;
|
|
146
148
|
|
|
147
|
-
auto * result = new
|
|
149
|
+
auto * result = new common_sampler {
|
|
148
150
|
/* .params = */ params,
|
|
149
151
|
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
|
150
152
|
/* .chain = */ llama_sampler_chain_init(lparams),
|
|
@@ -159,72 +161,63 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|
|
159
161
|
params.logit_bias.size(),
|
|
160
162
|
params.logit_bias.data()));
|
|
161
163
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
164
|
+
if (params.mirostat == 0) {
|
|
165
|
+
for (const auto & cnstr : params.samplers) {
|
|
166
|
+
switch (cnstr) {
|
|
167
|
+
case COMMON_SAMPLER_TYPE_DRY:
|
|
168
|
+
{
|
|
169
|
+
std::vector<const char *> c_breakers;
|
|
170
|
+
c_breakers.reserve(params.dry_sequence_breakers.size());
|
|
171
|
+
for (const auto & str : params.dry_sequence_breakers) {
|
|
172
|
+
c_breakers.push_back(str.c_str());
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
|
176
|
+
}
|
|
177
|
+
break;
|
|
178
|
+
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
179
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
|
180
|
+
break;
|
|
181
|
+
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
182
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
|
183
|
+
break;
|
|
184
|
+
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
185
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
|
186
|
+
break;
|
|
187
|
+
case COMMON_SAMPLER_TYPE_XTC:
|
|
188
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
|
189
|
+
break;
|
|
190
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
191
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
|
192
|
+
break;
|
|
193
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
194
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
|
195
|
+
break;
|
|
196
|
+
case COMMON_SAMPLER_TYPE_INFILL:
|
|
197
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
|
198
|
+
break;
|
|
199
|
+
case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
200
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
|
201
|
+
break;
|
|
202
|
+
default:
|
|
203
|
+
GGML_ASSERT(false && "unknown sampler type");
|
|
199
204
|
}
|
|
200
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
|
201
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
|
202
|
-
} else if (params.mirostat == 1) {
|
|
203
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
204
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
205
|
-
} else if (params.mirostat == 2) {
|
|
206
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
207
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
|
208
|
-
} else {
|
|
209
|
-
GGML_ASSERT(false && "unknown mirostat version");
|
|
210
205
|
}
|
|
206
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
|
207
|
+
} else if (params.mirostat == 1) {
|
|
208
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
209
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
|
210
|
+
} else if (params.mirostat == 2) {
|
|
211
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
|
212
|
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
|
211
213
|
} else {
|
|
212
|
-
|
|
213
|
-
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
|
|
214
|
-
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
|
|
215
|
-
//
|
|
216
|
-
// the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
|
|
217
|
-
// it is much faster, since we avoid sorting all tokens and should give a good approximation
|
|
218
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
|
219
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
|
220
|
-
}
|
|
221
|
-
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
|
214
|
+
GGML_ASSERT(false && "unknown mirostat version");
|
|
222
215
|
}
|
|
223
216
|
|
|
224
217
|
return result;
|
|
225
218
|
}
|
|
226
219
|
|
|
227
|
-
void
|
|
220
|
+
void common_sampler_free(struct common_sampler * gsmpl) {
|
|
228
221
|
if (gsmpl) {
|
|
229
222
|
llama_sampler_free(gsmpl->grmr);
|
|
230
223
|
|
|
@@ -234,7 +227,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
|
|
234
227
|
}
|
|
235
228
|
}
|
|
236
229
|
|
|
237
|
-
void
|
|
230
|
+
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
|
238
231
|
if (accept_grammar) {
|
|
239
232
|
llama_sampler_accept(gsmpl->grmr, token);
|
|
240
233
|
}
|
|
@@ -244,14 +237,14 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool acce
|
|
|
244
237
|
gsmpl->prev.push_back(token);
|
|
245
238
|
}
|
|
246
239
|
|
|
247
|
-
void
|
|
240
|
+
void common_sampler_reset(struct common_sampler * gsmpl) {
|
|
248
241
|
llama_sampler_reset(gsmpl->grmr);
|
|
249
242
|
|
|
250
243
|
llama_sampler_reset(gsmpl->chain);
|
|
251
244
|
}
|
|
252
245
|
|
|
253
|
-
struct
|
|
254
|
-
return new
|
|
246
|
+
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
|
247
|
+
return new common_sampler {
|
|
255
248
|
/* .params = */ gsmpl->params,
|
|
256
249
|
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
|
257
250
|
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
|
@@ -261,7 +254,7 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
|
|
261
254
|
};
|
|
262
255
|
}
|
|
263
256
|
|
|
264
|
-
void
|
|
257
|
+
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
|
265
258
|
// TODO: measure grammar performance
|
|
266
259
|
|
|
267
260
|
if (gsmpl) {
|
|
@@ -272,7 +265,7 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
|
|
|
272
265
|
}
|
|
273
266
|
}
|
|
274
267
|
|
|
275
|
-
llama_token
|
|
268
|
+
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
|
276
269
|
gsmpl->set_logits(ctx, idx);
|
|
277
270
|
|
|
278
271
|
auto & grmr = gsmpl->grmr;
|
|
@@ -318,21 +311,60 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
|
|
318
311
|
return cur_p.data[cur_p.selected].id;
|
|
319
312
|
}
|
|
320
313
|
|
|
321
|
-
|
|
314
|
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
|
315
|
+
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
|
316
|
+
|
|
317
|
+
std::vector<llama_token> result;
|
|
318
|
+
result.reserve(idxs.size());
|
|
319
|
+
|
|
320
|
+
size_t i = 0;
|
|
321
|
+
for (; i < draft.size(); i++) {
|
|
322
|
+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
|
323
|
+
|
|
324
|
+
common_sampler_accept(gsmpl, id, true);
|
|
325
|
+
|
|
326
|
+
result.push_back(id);
|
|
327
|
+
|
|
328
|
+
if (draft[i] != id) {
|
|
329
|
+
break;
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
if (i == draft.size()) {
|
|
334
|
+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
|
335
|
+
|
|
336
|
+
common_sampler_accept(gsmpl, id, true);
|
|
337
|
+
|
|
338
|
+
result.push_back(id);
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
return result;
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
|
345
|
+
std::vector<int> idxs(draft.size() + 1);
|
|
346
|
+
for (size_t i = 0; i < idxs.size(); ++i) {
|
|
347
|
+
idxs[i] = i;
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
|
322
354
|
return llama_sampler_get_seed(gsmpl->chain);
|
|
323
355
|
}
|
|
324
356
|
|
|
325
357
|
// helpers
|
|
326
358
|
|
|
327
|
-
llama_token_data_array *
|
|
359
|
+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
|
|
328
360
|
return &gsmpl->cur_p;
|
|
329
361
|
}
|
|
330
362
|
|
|
331
|
-
llama_token
|
|
363
|
+
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
|
332
364
|
return gsmpl->prev.rat(0);
|
|
333
365
|
}
|
|
334
366
|
|
|
335
|
-
std::string
|
|
367
|
+
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
|
336
368
|
std::string result = "logits ";
|
|
337
369
|
|
|
338
370
|
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
|
@@ -343,7 +375,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
|
|
343
375
|
return result;
|
|
344
376
|
}
|
|
345
377
|
|
|
346
|
-
std::string
|
|
378
|
+
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
|
|
347
379
|
n = std::min(n, (int) gsmpl->prev.size());
|
|
348
380
|
|
|
349
381
|
if (n <= 0) {
|
|
@@ -358,63 +390,70 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main,
|
|
|
358
390
|
|
|
359
391
|
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
|
|
360
392
|
|
|
361
|
-
result +=
|
|
393
|
+
result += common_token_to_piece(ctx_main, id);
|
|
362
394
|
}
|
|
363
395
|
|
|
364
396
|
return result;
|
|
365
397
|
}
|
|
366
398
|
|
|
367
|
-
char
|
|
399
|
+
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|
368
400
|
switch (cnstr) {
|
|
369
|
-
case
|
|
370
|
-
case
|
|
371
|
-
case
|
|
372
|
-
case
|
|
373
|
-
case
|
|
374
|
-
case
|
|
401
|
+
case COMMON_SAMPLER_TYPE_DRY: return 'd';
|
|
402
|
+
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
|
403
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
|
404
|
+
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
|
405
|
+
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
|
406
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
|
407
|
+
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
|
408
|
+
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
|
409
|
+
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
|
375
410
|
default : return '?';
|
|
376
411
|
}
|
|
377
412
|
}
|
|
378
413
|
|
|
379
|
-
std::string
|
|
414
|
+
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|
380
415
|
switch (cnstr) {
|
|
381
|
-
case
|
|
382
|
-
case
|
|
383
|
-
case
|
|
384
|
-
case
|
|
385
|
-
case
|
|
386
|
-
case
|
|
416
|
+
case COMMON_SAMPLER_TYPE_DRY: return "dry";
|
|
417
|
+
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
|
418
|
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
|
419
|
+
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
|
420
|
+
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
|
421
|
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
|
422
|
+
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
|
423
|
+
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
|
424
|
+
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
|
387
425
|
default : return "";
|
|
388
426
|
}
|
|
389
427
|
}
|
|
390
428
|
|
|
391
|
-
std::vector<
|
|
392
|
-
std::unordered_map<std::string,
|
|
393
|
-
{ "
|
|
394
|
-
{ "
|
|
395
|
-
{ "
|
|
396
|
-
{ "
|
|
397
|
-
{ "
|
|
398
|
-
{ "temperature",
|
|
429
|
+
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
|
430
|
+
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
|
|
431
|
+
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
|
432
|
+
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
433
|
+
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
434
|
+
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
435
|
+
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
436
|
+
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
437
|
+
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
|
438
|
+
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
|
439
|
+
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
|
399
440
|
};
|
|
400
441
|
|
|
401
442
|
// since samplers names are written multiple ways
|
|
402
443
|
// make it ready for both system names and input names
|
|
403
|
-
std::unordered_map<std::string,
|
|
404
|
-
{ "top-k",
|
|
405
|
-
{ "top-p",
|
|
406
|
-
{ "nucleus",
|
|
407
|
-
{ "typical-p",
|
|
408
|
-
{ "typical",
|
|
409
|
-
{ "typ-p",
|
|
410
|
-
{ "typ",
|
|
411
|
-
{ "min-p",
|
|
412
|
-
{ "
|
|
413
|
-
{ "tfs", GPT_SAMPLER_TYPE_TFS_Z },
|
|
414
|
-
{ "temp", GPT_SAMPLER_TYPE_TEMPERATURE },
|
|
444
|
+
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
|
445
|
+
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
|
446
|
+
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
|
447
|
+
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
|
448
|
+
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
449
|
+
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
450
|
+
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
451
|
+
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
452
|
+
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
|
453
|
+
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
415
454
|
};
|
|
416
455
|
|
|
417
|
-
std::vector<
|
|
456
|
+
std::vector<common_sampler_type> samplers;
|
|
418
457
|
samplers.reserve(names.size());
|
|
419
458
|
|
|
420
459
|
for (const auto & name : names) {
|
|
@@ -434,17 +473,20 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std
|
|
|
434
473
|
return samplers;
|
|
435
474
|
}
|
|
436
475
|
|
|
437
|
-
std::vector<
|
|
438
|
-
std::unordered_map<char,
|
|
439
|
-
{
|
|
440
|
-
{
|
|
441
|
-
{
|
|
442
|
-
{
|
|
443
|
-
{
|
|
444
|
-
{
|
|
476
|
+
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
|
|
477
|
+
std::unordered_map<char, common_sampler_type> sampler_name_map = {
|
|
478
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
|
|
479
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
|
480
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
|
481
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
|
482
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
|
483
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
|
484
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
|
485
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
|
486
|
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
|
445
487
|
};
|
|
446
488
|
|
|
447
|
-
std::vector<
|
|
489
|
+
std::vector<common_sampler_type> samplers;
|
|
448
490
|
samplers.reserve(chars.size());
|
|
449
491
|
|
|
450
492
|
for (const auto & c : chars) {
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
#include <string>
|
|
8
8
|
#include <vector>
|
|
9
9
|
|
|
10
|
-
//
|
|
10
|
+
// common_sampler extends llama_sampler with additional functionality:
|
|
11
11
|
//
|
|
12
12
|
// - grammar support
|
|
13
13
|
// - custom sampler logic based on the parameters
|
|
@@ -23,30 +23,30 @@
|
|
|
23
23
|
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
|
24
24
|
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
|
25
25
|
//
|
|
26
|
-
// The
|
|
26
|
+
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
|
|
27
27
|
// be moved into the core llama library.
|
|
28
28
|
//
|
|
29
|
-
// For convenience, the
|
|
29
|
+
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
|
|
30
30
|
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
|
31
31
|
//
|
|
32
32
|
// TODO: measure grammar performance
|
|
33
33
|
//
|
|
34
34
|
|
|
35
|
-
struct
|
|
35
|
+
struct common_sampler;
|
|
36
36
|
|
|
37
37
|
// llama_sampler API overloads
|
|
38
38
|
|
|
39
|
-
struct
|
|
39
|
+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
|
40
40
|
|
|
41
|
-
void
|
|
41
|
+
void common_sampler_free(struct common_sampler * gsmpl);
|
|
42
42
|
|
|
43
43
|
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
|
44
|
-
void
|
|
45
|
-
void
|
|
46
|
-
struct
|
|
44
|
+
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
|
45
|
+
void common_sampler_reset (struct common_sampler * gsmpl);
|
|
46
|
+
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
|
47
47
|
|
|
48
48
|
// arguments can be nullptr to skip printing
|
|
49
|
-
void
|
|
49
|
+
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
|
50
50
|
|
|
51
51
|
// extended sampling implementation:
|
|
52
52
|
//
|
|
@@ -58,26 +58,47 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
|
|
|
58
58
|
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
|
59
59
|
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
|
60
60
|
//
|
|
61
|
-
llama_token
|
|
61
|
+
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
|
62
62
|
|
|
63
|
-
|
|
63
|
+
// generalized version of common_sampler_sample
|
|
64
|
+
//
|
|
65
|
+
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
|
66
|
+
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
|
67
|
+
//
|
|
68
|
+
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
|
|
69
|
+
//
|
|
70
|
+
// is equivalent to
|
|
71
|
+
//
|
|
72
|
+
// common_sampler_sample(gsmpl, ctx, idx);
|
|
73
|
+
// common_sampler_accept(gsmpl, token, true);
|
|
74
|
+
//
|
|
75
|
+
// requires: idxs.size() == draft.size() + 1
|
|
76
|
+
//
|
|
77
|
+
// returns at least 1 token, up to idxs.size()
|
|
78
|
+
//
|
|
79
|
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
|
80
|
+
|
|
81
|
+
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
|
82
|
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
|
83
|
+
|
|
84
|
+
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
|
64
85
|
|
|
65
86
|
// helpers
|
|
66
87
|
|
|
67
88
|
// access the internal list of current candidate tokens
|
|
68
|
-
llama_token_data_array *
|
|
89
|
+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
|
|
69
90
|
|
|
70
91
|
// get the last accepted token
|
|
71
|
-
llama_token
|
|
92
|
+
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
|
72
93
|
|
|
73
94
|
// print the sampler chain into a string
|
|
74
|
-
std::string
|
|
95
|
+
std::string common_sampler_print(const struct common_sampler * gsmpl);
|
|
75
96
|
|
|
76
97
|
// get a string representation of the last accepted tokens
|
|
77
|
-
std::string
|
|
98
|
+
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
|
|
78
99
|
|
|
79
|
-
char
|
|
80
|
-
std::string
|
|
100
|
+
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
|
|
101
|
+
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
|
81
102
|
|
|
82
|
-
std::vector<enum
|
|
83
|
-
std::vector<enum
|
|
103
|
+
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
|
104
|
+
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|