whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -4,6 +4,9 @@
|
|
|
4
4
|
#include "llama-vocab.h"
|
|
5
5
|
#include "llama-grammar.h"
|
|
6
6
|
|
|
7
|
+
#include "ggml-cpp.h"
|
|
8
|
+
|
|
9
|
+
#include <array>
|
|
7
10
|
#include <algorithm>
|
|
8
11
|
#include <cassert>
|
|
9
12
|
#include <cfloat>
|
|
@@ -345,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
|
|
|
345
348
|
|
|
346
349
|
// llama_sampler API
|
|
347
350
|
|
|
348
|
-
struct llama_sampler * llama_sampler_init(
|
|
351
|
+
struct llama_sampler * llama_sampler_init(
|
|
352
|
+
struct llama_sampler_i * iface,
|
|
353
|
+
llama_sampler_context_t ctx) {
|
|
349
354
|
return new llama_sampler {
|
|
350
355
|
/* .iface = */ iface,
|
|
351
356
|
/* .ctx = */ ctx,
|
|
@@ -361,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
|
|
361
366
|
}
|
|
362
367
|
|
|
363
368
|
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
|
369
|
+
if (!smpl) {
|
|
370
|
+
return;
|
|
371
|
+
}
|
|
372
|
+
|
|
364
373
|
if (smpl->iface->accept) {
|
|
365
374
|
smpl->iface->accept(smpl, token);
|
|
366
375
|
}
|
|
367
376
|
}
|
|
368
377
|
|
|
369
378
|
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
|
379
|
+
if (!smpl) {
|
|
380
|
+
return;
|
|
381
|
+
}
|
|
382
|
+
|
|
370
383
|
GGML_ASSERT(smpl->iface->apply);
|
|
371
384
|
smpl->iface->apply(smpl, cur_p);
|
|
372
385
|
}
|
|
373
386
|
|
|
374
387
|
void llama_sampler_reset(struct llama_sampler * smpl) {
|
|
388
|
+
if (!smpl) {
|
|
389
|
+
return;
|
|
390
|
+
}
|
|
391
|
+
|
|
375
392
|
if (smpl->iface->reset) {
|
|
376
393
|
smpl->iface->reset(smpl);
|
|
377
394
|
}
|
|
378
395
|
}
|
|
379
396
|
|
|
380
397
|
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
|
398
|
+
if (!smpl) {
|
|
399
|
+
return nullptr;
|
|
400
|
+
}
|
|
401
|
+
|
|
381
402
|
if (smpl->iface->clone) {
|
|
382
403
|
return smpl->iface->clone(smpl);
|
|
383
404
|
}
|
|
@@ -404,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
|
|
|
404
425
|
delete smpl;
|
|
405
426
|
}
|
|
406
427
|
|
|
407
|
-
|
|
408
|
-
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
428
|
+
// empty sampler
|
|
409
429
|
|
|
410
|
-
|
|
411
|
-
const
|
|
430
|
+
struct llama_sampler_empty {
|
|
431
|
+
const char * name;
|
|
432
|
+
};
|
|
412
433
|
|
|
413
|
-
|
|
434
|
+
static struct llama_sampler * llama_sampler_init_empty(const char * name);
|
|
435
|
+
|
|
436
|
+
static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
|
|
437
|
+
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
|
438
|
+
return ctx->name;
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
|
|
442
|
+
GGML_UNUSED(smpl);
|
|
443
|
+
GGML_UNUSED(token);
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
447
|
+
GGML_UNUSED(smpl);
|
|
448
|
+
GGML_UNUSED(cur_p);
|
|
449
|
+
}
|
|
450
|
+
|
|
451
|
+
static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
|
|
452
|
+
GGML_UNUSED(smpl);
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
|
|
456
|
+
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
|
457
|
+
return llama_sampler_init_empty(ctx->name);
|
|
458
|
+
}
|
|
414
459
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
460
|
+
static void llama_sampler_empty_free(struct llama_sampler * smpl) {
|
|
461
|
+
delete (llama_sampler_empty *) smpl->ctx;
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
static bool llama_sampler_empty_backend_init(
|
|
465
|
+
struct llama_sampler * smpl,
|
|
466
|
+
ggml_backend_buffer_type_t buft) {
|
|
467
|
+
GGML_UNUSED(smpl);
|
|
468
|
+
GGML_UNUSED(buft);
|
|
469
|
+
|
|
470
|
+
return true;
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
static void llama_sampler_empty_backend_accept(
|
|
474
|
+
struct llama_sampler * smpl,
|
|
475
|
+
ggml_context * ctx,
|
|
476
|
+
ggml_cgraph * gf,
|
|
477
|
+
struct ggml_tensor * selected_token) {
|
|
478
|
+
GGML_UNUSED(smpl);
|
|
479
|
+
GGML_UNUSED(ctx);
|
|
480
|
+
GGML_UNUSED(gf);
|
|
481
|
+
GGML_UNUSED(selected_token);
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
static void llama_sampler_empty_backend_apply(
|
|
485
|
+
struct llama_sampler * smpl,
|
|
486
|
+
struct ggml_context * ctx,
|
|
487
|
+
struct ggml_cgraph * gf,
|
|
488
|
+
struct llama_sampler_data * data) {
|
|
489
|
+
GGML_UNUSED(smpl);
|
|
490
|
+
GGML_UNUSED(ctx);
|
|
491
|
+
GGML_UNUSED(gf);
|
|
492
|
+
GGML_UNUSED(data);
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
|
|
496
|
+
GGML_UNUSED(smpl);
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
static struct llama_sampler_i llama_sampler_empty_i = {
|
|
500
|
+
/* .name = */ llama_sampler_empty_name,
|
|
501
|
+
/* .accept = */ llama_sampler_empty_accept,
|
|
502
|
+
/* .apply = */ llama_sampler_empty_apply,
|
|
503
|
+
/* .reset = */ llama_sampler_empty_reset,
|
|
504
|
+
/* .clone = */ llama_sampler_empty_clone,
|
|
505
|
+
/* .free = */ llama_sampler_empty_free,
|
|
506
|
+
/* .backend_init = */ llama_sampler_empty_backend_init,
|
|
507
|
+
/* .backend_accept = */ llama_sampler_empty_backend_accept,
|
|
508
|
+
/* .backend_apply = */ llama_sampler_empty_backend_apply,
|
|
509
|
+
/* .backend_set_input = */ llama_sampler_empty_backend_set_input,
|
|
510
|
+
};
|
|
511
|
+
|
|
512
|
+
struct llama_sampler * llama_sampler_init_empty(const char * name) {
|
|
513
|
+
return llama_sampler_init(
|
|
514
|
+
/* .iface = */ &llama_sampler_empty_i,
|
|
515
|
+
/* .ctx = */ new llama_sampler_empty {
|
|
516
|
+
/* .name = */ name,
|
|
517
|
+
}
|
|
518
|
+
);
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
// common backend sampler functionality
|
|
522
|
+
//
|
|
523
|
+
// +name : means that the sampler is support and will run on the backend
|
|
524
|
+
// -name : means that a ggml operator is not supported by the backend
|
|
525
|
+
//
|
|
526
|
+
struct llama_sampler_backend {
|
|
527
|
+
llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
|
|
528
|
+
|
|
529
|
+
const char * get_name() {
|
|
530
|
+
if (!is_init) {
|
|
531
|
+
return name.c_str();
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
if (support) {
|
|
535
|
+
name_ext = "+" + name;
|
|
536
|
+
} else {
|
|
537
|
+
name_ext = "-" + name;
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
return name_ext.c_str();
|
|
420
541
|
}
|
|
421
542
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
543
|
+
void init(bool support) {
|
|
544
|
+
GGML_ASSERT(this->is_init == false);
|
|
545
|
+
|
|
546
|
+
this->is_init = true;
|
|
547
|
+
this->support = support;
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
private:
|
|
551
|
+
std::string name;
|
|
552
|
+
std::string name_ext;
|
|
553
|
+
|
|
554
|
+
bool is_init;
|
|
555
|
+
bool support;
|
|
556
|
+
};
|
|
557
|
+
|
|
558
|
+
// check if all ggml ops used by the sampler are supported by the backend
|
|
559
|
+
static bool llama_sampler_backend_support(
|
|
560
|
+
llama_sampler * smpl,
|
|
561
|
+
ggml_backend_buffer_type_t buft) {
|
|
562
|
+
auto * device = ggml_backend_buft_get_device(buft);
|
|
563
|
+
if (!device) {
|
|
564
|
+
// CPU backend always supported
|
|
565
|
+
return true;
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
ggml_init_params params = {
|
|
569
|
+
/*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
|
|
570
|
+
/*.mem_buffer =*/ NULL,
|
|
571
|
+
/*.no_alloc =*/ true,
|
|
427
572
|
};
|
|
428
573
|
|
|
429
|
-
|
|
574
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
575
|
+
if (!ctx_ptr) {
|
|
576
|
+
throw std::runtime_error(format("failed to create ggml context"));
|
|
577
|
+
}
|
|
430
578
|
|
|
431
|
-
|
|
579
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
432
580
|
|
|
433
|
-
|
|
581
|
+
const int64_t n = 1024*1024;
|
|
434
582
|
|
|
435
|
-
|
|
583
|
+
llama_sampler_data data = {
|
|
584
|
+
/*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
|
|
585
|
+
/*.probs = */ nullptr,
|
|
586
|
+
/*.sampled = */ nullptr,
|
|
587
|
+
/*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
|
|
588
|
+
};
|
|
436
589
|
|
|
437
|
-
|
|
590
|
+
ggml_cgraph * gf = ggml_new_graph(ctx);
|
|
591
|
+
|
|
592
|
+
smpl->iface->backend_apply(smpl, ctx, gf, &data);
|
|
593
|
+
|
|
594
|
+
if (data.logits) {
|
|
595
|
+
ggml_build_forward_expand(gf, data.logits);
|
|
596
|
+
}
|
|
597
|
+
|
|
598
|
+
if (data.probs) {
|
|
599
|
+
ggml_build_forward_expand(gf, data.probs);
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
if (data.sampled) {
|
|
603
|
+
ggml_build_forward_expand(gf, data.sampled);
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
if (data.candidates) {
|
|
607
|
+
ggml_build_forward_expand(gf, data.candidates);
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
611
|
+
struct ggml_tensor * op = ggml_graph_node(gf, i);
|
|
612
|
+
|
|
613
|
+
if (!ggml_backend_dev_supports_op(device, op)) {
|
|
614
|
+
LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
|
|
615
|
+
__func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
|
|
616
|
+
|
|
617
|
+
return false;
|
|
618
|
+
}
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
return true;
|
|
438
622
|
}
|
|
439
623
|
|
|
440
624
|
// sampler chain
|
|
@@ -448,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
|
|
|
448
632
|
|
|
449
633
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
450
634
|
|
|
451
|
-
for (auto
|
|
452
|
-
llama_sampler_accept(smpl, token);
|
|
635
|
+
for (auto & smpl : chain->samplers) {
|
|
636
|
+
llama_sampler_accept(smpl.ptr, token);
|
|
453
637
|
}
|
|
454
638
|
|
|
455
639
|
chain->n_sample++;
|
|
@@ -460,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
|
|
|
460
644
|
|
|
461
645
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
462
646
|
|
|
463
|
-
|
|
464
|
-
|
|
647
|
+
bool is_backend = chain->is_init;
|
|
648
|
+
|
|
649
|
+
for (auto & smpl : chain->samplers) {
|
|
650
|
+
if (is_backend && smpl.is_backend) {
|
|
651
|
+
continue;
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
is_backend = false;
|
|
655
|
+
|
|
656
|
+
if (smpl.ptr->iface->apply == nullptr) {
|
|
657
|
+
continue;
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
llama_sampler_apply(smpl.ptr, cur_p);
|
|
465
661
|
}
|
|
466
662
|
}
|
|
467
663
|
|
|
468
664
|
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
|
469
665
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
470
666
|
|
|
471
|
-
for (auto
|
|
472
|
-
llama_sampler_reset(smpl);
|
|
667
|
+
for (auto & smpl : chain->samplers) {
|
|
668
|
+
llama_sampler_reset(smpl.ptr);
|
|
473
669
|
}
|
|
474
|
-
|
|
475
|
-
chain->t_sample_us = 0;
|
|
476
|
-
chain->n_sample = 0;
|
|
477
670
|
}
|
|
478
671
|
|
|
479
672
|
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
|
@@ -481,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
|
|
481
674
|
|
|
482
675
|
auto * result = llama_sampler_chain_init(chain_src->params);
|
|
483
676
|
|
|
484
|
-
for (auto
|
|
485
|
-
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
|
677
|
+
for (const auto & smpl : chain_src->samplers) {
|
|
678
|
+
llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
|
|
486
679
|
}
|
|
487
680
|
|
|
488
681
|
return result;
|
|
@@ -491,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
|
|
491
684
|
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
|
492
685
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
493
686
|
|
|
494
|
-
for (auto
|
|
495
|
-
llama_sampler_free(smpl);
|
|
687
|
+
for (auto & smpl : chain->samplers) {
|
|
688
|
+
llama_sampler_free(smpl.ptr);
|
|
496
689
|
}
|
|
497
690
|
|
|
498
691
|
delete chain;
|
|
499
692
|
}
|
|
500
693
|
|
|
694
|
+
static bool llama_sampler_chain_backend_init(
|
|
695
|
+
struct llama_sampler * smpl,
|
|
696
|
+
ggml_backend_buffer_type_t buft) {
|
|
697
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
698
|
+
|
|
699
|
+
GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
|
|
700
|
+
|
|
701
|
+
chain->is_init = true;
|
|
702
|
+
|
|
703
|
+
bool res = true;
|
|
704
|
+
|
|
705
|
+
for (auto & smpl : chain->samplers) {
|
|
706
|
+
bool res_cur = true;
|
|
707
|
+
|
|
708
|
+
// to be able to run a sampler on the backend, it has to:
|
|
709
|
+
// - have the .backend_init() API implemented
|
|
710
|
+
// - return true during .backend_init()
|
|
711
|
+
if (smpl.ptr->iface->backend_init) {
|
|
712
|
+
if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
|
|
713
|
+
res_cur = false;
|
|
714
|
+
}
|
|
715
|
+
} else {
|
|
716
|
+
res_cur = false;
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
smpl.is_backend = res_cur;
|
|
720
|
+
|
|
721
|
+
res = res && res_cur;
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
return res;
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
static void llama_sampler_chain_backend_accept(
|
|
728
|
+
struct llama_sampler * smpl,
|
|
729
|
+
ggml_context * ctx,
|
|
730
|
+
ggml_cgraph * gf,
|
|
731
|
+
struct ggml_tensor * selected_token) {
|
|
732
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
733
|
+
|
|
734
|
+
for (auto & smpl : chain->samplers) {
|
|
735
|
+
if (!smpl.is_backend) {
|
|
736
|
+
break;
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
if (smpl.ptr->iface->backend_accept) {
|
|
740
|
+
smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
|
|
741
|
+
}
|
|
742
|
+
}
|
|
743
|
+
}
|
|
744
|
+
|
|
745
|
+
static void llama_sampler_chain_backend_apply(
|
|
746
|
+
struct llama_sampler * smpl,
|
|
747
|
+
struct ggml_context * ctx,
|
|
748
|
+
struct ggml_cgraph * gf,
|
|
749
|
+
struct llama_sampler_data * data) {
|
|
750
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
751
|
+
|
|
752
|
+
GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
|
|
753
|
+
|
|
754
|
+
for (auto & smpl : chain->samplers) {
|
|
755
|
+
if (!smpl.is_backend) {
|
|
756
|
+
break;
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
if (smpl.ptr->iface->backend_apply) {
|
|
760
|
+
smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
|
|
761
|
+
}
|
|
762
|
+
}
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
|
|
766
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
767
|
+
|
|
768
|
+
for (auto & smpl : chain->samplers) {
|
|
769
|
+
if (!smpl.is_backend) {
|
|
770
|
+
break;
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
if (smpl.ptr->iface->backend_set_input) {
|
|
774
|
+
smpl.ptr->iface->backend_set_input(smpl.ptr);
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
|
|
501
779
|
static struct llama_sampler_i llama_sampler_chain_i = {
|
|
502
|
-
/* .name
|
|
503
|
-
/* .accept
|
|
504
|
-
/* .apply
|
|
505
|
-
/* .reset
|
|
506
|
-
/* .clone
|
|
507
|
-
/* .free
|
|
780
|
+
/* .name = */ llama_sampler_chain_name,
|
|
781
|
+
/* .accept = */ llama_sampler_chain_accept,
|
|
782
|
+
/* .apply = */ llama_sampler_chain_apply,
|
|
783
|
+
/* .reset = */ llama_sampler_chain_reset,
|
|
784
|
+
/* .clone = */ llama_sampler_chain_clone,
|
|
785
|
+
/* .free = */ llama_sampler_chain_free,
|
|
786
|
+
/* .backend_init = */ llama_sampler_chain_backend_init,
|
|
787
|
+
/* .backend_accept = */ llama_sampler_chain_backend_accept,
|
|
788
|
+
/* .backend_apply = */ llama_sampler_chain_backend_apply,
|
|
789
|
+
/* .backend_set_input = */ llama_sampler_chain_backend_set_input,
|
|
508
790
|
};
|
|
509
791
|
|
|
510
792
|
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
|
@@ -512,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
|
|
512
794
|
/* .iface = */ &llama_sampler_chain_i,
|
|
513
795
|
/* .ctx = */ new llama_sampler_chain {
|
|
514
796
|
/* .params = */ params,
|
|
797
|
+
/* .is_init = */ false,
|
|
515
798
|
/* .samplers = */ {},
|
|
799
|
+
/* .cur = */ {},
|
|
516
800
|
/* .t_sample_us = */ 0,
|
|
517
801
|
/* .n_sample = */ 0,
|
|
518
802
|
}
|
|
519
803
|
);
|
|
520
804
|
}
|
|
521
805
|
|
|
806
|
+
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
|
807
|
+
const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
|
|
808
|
+
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
|
|
809
|
+
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
|
|
810
|
+
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
|
811
|
+
|
|
812
|
+
// If a backend sampler has already sampled a token, return it.
|
|
813
|
+
if (sampled_token != LLAMA_TOKEN_NULL) {
|
|
814
|
+
LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
|
|
815
|
+
return sampled_token;
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
const llama_model * model = llama_get_model(ctx);
|
|
819
|
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
820
|
+
|
|
821
|
+
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
822
|
+
|
|
823
|
+
// use pre-allocated buffer from chain if available, otherwise allocate locally
|
|
824
|
+
std::vector<llama_token_data> * cur_ptr;
|
|
825
|
+
std::vector<llama_token_data> cur_local;
|
|
826
|
+
|
|
827
|
+
if (smpl->iface == &llama_sampler_chain_i) {
|
|
828
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
829
|
+
cur_ptr = &chain->cur;
|
|
830
|
+
} else {
|
|
831
|
+
cur_ptr = &cur_local;
|
|
832
|
+
}
|
|
833
|
+
|
|
834
|
+
auto & cur = *cur_ptr;
|
|
835
|
+
|
|
836
|
+
if (sampled_probs) {
|
|
837
|
+
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
|
|
838
|
+
cur.resize(sampled_probs_count);
|
|
839
|
+
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
|
840
|
+
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
|
|
841
|
+
}
|
|
842
|
+
} else if (sampled_logits) {
|
|
843
|
+
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
|
|
844
|
+
cur.resize(sampled_logits_count);
|
|
845
|
+
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
|
846
|
+
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
|
|
847
|
+
}
|
|
848
|
+
} else {
|
|
849
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
850
|
+
GGML_ASSERT(logits != nullptr);
|
|
851
|
+
cur.resize(n_vocab);
|
|
852
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
853
|
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
|
854
|
+
}
|
|
855
|
+
}
|
|
856
|
+
|
|
857
|
+
llama_token_data_array cur_p = {
|
|
858
|
+
/* .data = */ cur.data(),
|
|
859
|
+
/* .size = */ cur.size(),
|
|
860
|
+
/* .selected = */ -1,
|
|
861
|
+
/* .sorted = */ false,
|
|
862
|
+
};
|
|
863
|
+
|
|
864
|
+
llama_sampler_apply(smpl, &cur_p);
|
|
865
|
+
|
|
866
|
+
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
|
|
867
|
+
|
|
868
|
+
auto token = cur_p.data[cur_p.selected].id;
|
|
869
|
+
|
|
870
|
+
llama_sampler_accept(smpl, token);
|
|
871
|
+
|
|
872
|
+
return token;
|
|
873
|
+
}
|
|
874
|
+
|
|
875
|
+
|
|
522
876
|
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
|
523
877
|
auto * p = (llama_sampler_chain *) chain->ctx;
|
|
524
|
-
p->samplers.push_back(
|
|
878
|
+
p->samplers.push_back({
|
|
879
|
+
/* .is_backend = */ false,
|
|
880
|
+
/* .ptr = */ smpl,
|
|
881
|
+
});
|
|
525
882
|
}
|
|
526
883
|
|
|
527
|
-
struct llama_sampler * llama_sampler_chain_get(
|
|
884
|
+
struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
|
|
885
|
+
if (chain == nullptr) {
|
|
886
|
+
return nullptr;
|
|
887
|
+
}
|
|
888
|
+
|
|
889
|
+
if (chain->iface != &llama_sampler_chain_i) {
|
|
890
|
+
return nullptr;
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
if (i == -1) {
|
|
894
|
+
return chain;
|
|
895
|
+
}
|
|
896
|
+
|
|
528
897
|
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
|
529
898
|
|
|
530
899
|
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
|
531
900
|
return nullptr;
|
|
532
901
|
}
|
|
533
902
|
|
|
534
|
-
return p->samplers[i];
|
|
903
|
+
return p->samplers[i].ptr;
|
|
535
904
|
}
|
|
536
905
|
|
|
537
906
|
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
|
@@ -541,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
|
|
|
541
910
|
return nullptr;
|
|
542
911
|
}
|
|
543
912
|
|
|
544
|
-
auto * result = p->samplers[i];
|
|
913
|
+
auto * result = p->samplers[i].ptr;
|
|
545
914
|
p->samplers.erase(p->samplers.begin() + i);
|
|
546
915
|
|
|
547
916
|
return result;
|
|
@@ -559,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
|
|
559
928
|
|
|
560
929
|
// greedy
|
|
561
930
|
|
|
562
|
-
|
|
563
|
-
|
|
931
|
+
struct llama_sampler_greedy : public llama_sampler_backend {
|
|
932
|
+
};
|
|
933
|
+
|
|
934
|
+
static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
|
|
935
|
+
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
|
|
936
|
+
return sctx->get_name();
|
|
937
|
+
}
|
|
938
|
+
|
|
939
|
+
static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
|
|
940
|
+
auto * ctx = (llama_sampler_greedy *) smpl->ctx;
|
|
941
|
+
GGML_UNUSED(ctx);
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
|
|
945
|
+
const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
|
|
946
|
+
auto * result = llama_sampler_init_greedy();
|
|
947
|
+
|
|
948
|
+
// copy the state
|
|
949
|
+
{
|
|
950
|
+
auto * result_ctx = (llama_sampler_greedy *) result->ctx;
|
|
951
|
+
|
|
952
|
+
GGML_UNUSED(ctx);
|
|
953
|
+
GGML_UNUSED(result_ctx);
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
return result;
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
|
|
960
|
+
delete (llama_sampler_greedy *) smpl->ctx;
|
|
564
961
|
}
|
|
565
962
|
|
|
566
963
|
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
|
@@ -572,33 +969,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
|
|
572
969
|
}
|
|
573
970
|
}
|
|
574
971
|
|
|
972
|
+
static bool llama_sampler_greedy_backend_init(
|
|
973
|
+
struct llama_sampler * smpl,
|
|
974
|
+
ggml_backend_buffer_type_t buft) {
|
|
975
|
+
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
|
|
976
|
+
|
|
977
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
978
|
+
|
|
979
|
+
sctx->init(res);
|
|
980
|
+
|
|
981
|
+
return res;
|
|
982
|
+
}
|
|
983
|
+
|
|
984
|
+
static void llama_sampler_greedy_backend_apply(
|
|
985
|
+
struct llama_sampler * smpl,
|
|
986
|
+
struct ggml_context * ctx,
|
|
987
|
+
struct ggml_cgraph * gf,
|
|
988
|
+
struct llama_sampler_data * data) {
|
|
989
|
+
GGML_UNUSED(gf);
|
|
990
|
+
GGML_UNUSED(smpl);
|
|
991
|
+
|
|
992
|
+
struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
|
|
993
|
+
ggml_set_name(curl, "greedy_argmax");
|
|
994
|
+
|
|
995
|
+
data->sampled = curl;
|
|
996
|
+
}
|
|
997
|
+
|
|
575
998
|
static struct llama_sampler_i llama_sampler_greedy_i = {
|
|
576
|
-
/* .name
|
|
577
|
-
/* .accept
|
|
578
|
-
/* .apply
|
|
579
|
-
/* .reset
|
|
580
|
-
/* .clone
|
|
581
|
-
/* .free
|
|
999
|
+
/* .name = */ llama_sampler_greedy_name,
|
|
1000
|
+
/* .accept = */ nullptr,
|
|
1001
|
+
/* .apply = */ llama_sampler_greedy_apply,
|
|
1002
|
+
/* .reset = */ llama_sampler_greedy_reset,
|
|
1003
|
+
/* .clone = */ llama_sampler_greedy_clone,
|
|
1004
|
+
/* .free = */ llama_sampler_greedy_free,
|
|
1005
|
+
/* .backend_init = */ llama_sampler_greedy_backend_init,
|
|
1006
|
+
/* .backend_accept = */ nullptr,
|
|
1007
|
+
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
|
|
1008
|
+
/* .backend_set_input = */ nullptr,
|
|
582
1009
|
};
|
|
583
1010
|
|
|
584
1011
|
struct llama_sampler * llama_sampler_init_greedy() {
|
|
585
1012
|
return llama_sampler_init(
|
|
586
1013
|
/* .iface = */ &llama_sampler_greedy_i,
|
|
587
|
-
/* .ctx = */
|
|
1014
|
+
/* .ctx = */ new llama_sampler_greedy {
|
|
1015
|
+
("greedy"),
|
|
1016
|
+
}
|
|
588
1017
|
);
|
|
589
1018
|
}
|
|
590
1019
|
|
|
591
1020
|
// dist
|
|
592
1021
|
|
|
593
|
-
struct llama_sampler_dist {
|
|
1022
|
+
struct llama_sampler_dist : public llama_sampler_backend {
|
|
594
1023
|
const uint32_t seed;
|
|
595
1024
|
uint32_t seed_cur;
|
|
596
1025
|
|
|
597
1026
|
std::mt19937 rng;
|
|
1027
|
+
|
|
1028
|
+
// backend input
|
|
1029
|
+
struct ggml_tensor * inp_uniform;
|
|
1030
|
+
|
|
1031
|
+
ggml_context_ptr inp_ctx;
|
|
1032
|
+
ggml_backend_buffer_ptr inp_buf;
|
|
598
1033
|
};
|
|
599
1034
|
|
|
600
|
-
static const char * llama_sampler_dist_name(const struct llama_sampler *
|
|
601
|
-
|
|
1035
|
+
static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
|
|
1036
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1037
|
+
return sctx->get_name();
|
|
602
1038
|
}
|
|
603
1039
|
|
|
604
1040
|
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -673,6 +1109,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|
|
673
1109
|
#endif
|
|
674
1110
|
}
|
|
675
1111
|
|
|
1112
|
+
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
1113
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
1114
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
1115
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
1116
|
+
}
|
|
1117
|
+
|
|
676
1118
|
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
|
677
1119
|
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
|
678
1120
|
auto * result = llama_sampler_init_dist(ctx->seed);
|
|
@@ -687,23 +1129,127 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
|
|
|
687
1129
|
return result;
|
|
688
1130
|
}
|
|
689
1131
|
|
|
690
|
-
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
691
|
-
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
692
|
-
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
693
|
-
ctx->rng.seed(ctx->seed_cur);
|
|
694
|
-
}
|
|
695
|
-
|
|
696
1132
|
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|
697
1133
|
delete (llama_sampler_dist *) smpl->ctx;
|
|
698
1134
|
}
|
|
699
1135
|
|
|
1136
|
+
static bool llama_sampler_dist_backend_init(
|
|
1137
|
+
struct llama_sampler * smpl,
|
|
1138
|
+
ggml_backend_buffer_type_t buft) {
|
|
1139
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1140
|
+
|
|
1141
|
+
// allocate inputs
|
|
1142
|
+
{
|
|
1143
|
+
ggml_init_params params = {
|
|
1144
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
1145
|
+
/*.mem_buffer =*/ nullptr,
|
|
1146
|
+
/*.no_alloc =*/ true,
|
|
1147
|
+
};
|
|
1148
|
+
|
|
1149
|
+
sctx->inp_ctx.reset(ggml_init(params));
|
|
1150
|
+
|
|
1151
|
+
// Create the uniform random scalar input tensor. This will be set by
|
|
1152
|
+
// llama_sampler_dist_backend_set_input after this graph is built.
|
|
1153
|
+
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
|
|
1154
|
+
ggml_set_name (sctx->inp_uniform, "uniform");
|
|
1155
|
+
ggml_set_input(sctx->inp_uniform);
|
|
1156
|
+
|
|
1157
|
+
// Allocate all tensors from our context to the backend
|
|
1158
|
+
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
|
1159
|
+
|
|
1160
|
+
ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
|
|
1161
|
+
}
|
|
1162
|
+
|
|
1163
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1164
|
+
|
|
1165
|
+
sctx->init(res);
|
|
1166
|
+
|
|
1167
|
+
if (!res) {
|
|
1168
|
+
sctx->inp_ctx.reset(nullptr);
|
|
1169
|
+
sctx->inp_buf.reset(nullptr);
|
|
1170
|
+
}
|
|
1171
|
+
|
|
1172
|
+
return res;
|
|
1173
|
+
}
|
|
1174
|
+
|
|
1175
|
+
static void llama_sampler_dist_backend_apply(
|
|
1176
|
+
struct llama_sampler * smpl,
|
|
1177
|
+
struct ggml_context * ctx,
|
|
1178
|
+
struct ggml_cgraph * gf,
|
|
1179
|
+
struct llama_sampler_data * data) {
|
|
1180
|
+
GGML_UNUSED(gf);
|
|
1181
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1182
|
+
|
|
1183
|
+
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
|
1184
|
+
ggml_set_name(probs, "dist_probs");
|
|
1185
|
+
|
|
1186
|
+
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
|
1187
|
+
ggml_set_name(cumsum, "dist_cumsum");
|
|
1188
|
+
|
|
1189
|
+
// The uniform tensor has a random value and we subtract this tensor with
|
|
1190
|
+
// the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
|
|
1191
|
+
// Recall that each entry in cumsum is the cumulative probability up to that
|
|
1192
|
+
// index so values stay negative while the cumulative total is below the
|
|
1193
|
+
// random value, and become zero/positive once the threshold is crossed.
|
|
1194
|
+
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
|
|
1195
|
+
ggml_set_name(diff, "dist_cumsum");
|
|
1196
|
+
|
|
1197
|
+
// The ggml_step function produces a tensor where entries are 1 if the
|
|
1198
|
+
// corresponding entry in diff is > 0, and 0 otherwise. So all values up to
|
|
1199
|
+
// the index where the cumulative probability exceeds the random value are 0,
|
|
1200
|
+
// and all entries after that are 1.
|
|
1201
|
+
struct ggml_tensor * mask = ggml_step(ctx, diff);
|
|
1202
|
+
ggml_set_name(mask, "dist_mask");
|
|
1203
|
+
|
|
1204
|
+
// Taking the sum of the mask gives us the sum of elements after the threshold
|
|
1205
|
+
// we are interested in.
|
|
1206
|
+
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
|
1207
|
+
ggml_set_name(idxf, "dist_index_f32");
|
|
1208
|
+
|
|
1209
|
+
// Use ggml_scale_bias to scale the index value by -1 and then add the size
|
|
1210
|
+
// of the mask to that value so we get the correct index ((-1 * idxf) + n).
|
|
1211
|
+
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
|
1212
|
+
ggml_set_name(idx, "dist_index_i32");
|
|
1213
|
+
|
|
1214
|
+
// Map back to original vocab ids if a candidates tensor is available.
|
|
1215
|
+
struct ggml_tensor * sampled_token = idx;
|
|
1216
|
+
if (data->candidates != nullptr) {
|
|
1217
|
+
struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
|
|
1218
|
+
|
|
1219
|
+
sampled_token = ggml_get_rows(ctx, candidates, idx);
|
|
1220
|
+
ggml_set_name(sampled_token, "dist_sampled_token");
|
|
1221
|
+
}
|
|
1222
|
+
|
|
1223
|
+
data->sampled = sampled_token;
|
|
1224
|
+
data->probs = probs;
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
|
1228
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1229
|
+
GGML_ASSERT(sctx->inp_uniform != nullptr);
|
|
1230
|
+
|
|
1231
|
+
// We sample in double precision and cast to float to match rnd numbers of
|
|
1232
|
+
// llama_dampler_dist which uses double precision (sampling from
|
|
1233
|
+
// std::uniform_real_distribution<double> and
|
|
1234
|
+
// std::uniform_real_distribution<float> with same rng will produce
|
|
1235
|
+
// different sequences).
|
|
1236
|
+
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
|
1237
|
+
const float rnd = dist(sctx->rng);
|
|
1238
|
+
|
|
1239
|
+
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
|
1240
|
+
}
|
|
1241
|
+
|
|
700
1242
|
static struct llama_sampler_i llama_sampler_dist_i = {
|
|
701
|
-
/* .name
|
|
702
|
-
/* .accept
|
|
703
|
-
/* .apply
|
|
704
|
-
/* .reset
|
|
705
|
-
/* .clone
|
|
706
|
-
/* .free
|
|
1243
|
+
/* .name = */ llama_sampler_dist_name,
|
|
1244
|
+
/* .accept = */ nullptr,
|
|
1245
|
+
/* .apply = */ llama_sampler_dist_apply,
|
|
1246
|
+
/* .reset = */ llama_sampler_dist_reset,
|
|
1247
|
+
/* .clone = */ llama_sampler_dist_clone,
|
|
1248
|
+
/* .free = */ llama_sampler_dist_free,
|
|
1249
|
+
/* .backend_init = */ llama_sampler_dist_backend_init,
|
|
1250
|
+
/* .backend_accept = */ nullptr,
|
|
1251
|
+
/* .backend_apply = */ llama_sampler_dist_backend_apply,
|
|
1252
|
+
/* .backend_set_input = */ llama_sampler_dist_backend_set_input,
|
|
707
1253
|
};
|
|
708
1254
|
|
|
709
1255
|
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
@@ -711,21 +1257,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
|
711
1257
|
return llama_sampler_init(
|
|
712
1258
|
/* .iface = */ &llama_sampler_dist_i,
|
|
713
1259
|
/* .ctx = */ new llama_sampler_dist {
|
|
714
|
-
|
|
715
|
-
/* .
|
|
716
|
-
/* .
|
|
1260
|
+
("dist"),
|
|
1261
|
+
/* .seed = */ seed,
|
|
1262
|
+
/* .seed_cur = */ seed_cur,
|
|
1263
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
1264
|
+
/* .inp_uniform = */ nullptr,
|
|
1265
|
+
/* .inp_ctx = */ nullptr,
|
|
1266
|
+
/* .inp_buf = */ nullptr,
|
|
717
1267
|
}
|
|
718
1268
|
);
|
|
719
1269
|
}
|
|
720
1270
|
|
|
721
1271
|
// top-k
|
|
722
1272
|
|
|
723
|
-
struct llama_sampler_top_k {
|
|
1273
|
+
struct llama_sampler_top_k : public llama_sampler_backend {
|
|
724
1274
|
const int32_t k;
|
|
725
1275
|
};
|
|
726
1276
|
|
|
727
|
-
static const char * llama_sampler_top_k_name(const struct llama_sampler *
|
|
728
|
-
|
|
1277
|
+
static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
|
|
1278
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1279
|
+
return sctx->get_name();
|
|
729
1280
|
}
|
|
730
1281
|
|
|
731
1282
|
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -742,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|
|
742
1293
|
delete (llama_sampler_top_k *) smpl->ctx;
|
|
743
1294
|
}
|
|
744
1295
|
|
|
1296
|
+
static bool llama_sampler_top_k_backend_init(
|
|
1297
|
+
struct llama_sampler * smpl,
|
|
1298
|
+
ggml_backend_buffer_type_t buft) {
|
|
1299
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1300
|
+
|
|
1301
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1302
|
+
|
|
1303
|
+
sctx->init(res);
|
|
1304
|
+
|
|
1305
|
+
return res;
|
|
1306
|
+
}
|
|
1307
|
+
|
|
1308
|
+
static void llama_sampler_top_k_backend_apply(
|
|
1309
|
+
struct llama_sampler * smpl,
|
|
1310
|
+
struct ggml_context * ctx,
|
|
1311
|
+
struct ggml_cgraph * gf,
|
|
1312
|
+
struct llama_sampler_data * data) {
|
|
1313
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1314
|
+
|
|
1315
|
+
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
|
|
1316
|
+
ggml_set_name(top_k, "top_k");
|
|
1317
|
+
|
|
1318
|
+
if (data->candidates) {
|
|
1319
|
+
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
|
1320
|
+
data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
|
|
1321
|
+
data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
|
|
1322
|
+
ggml_set_name(data->candidates, "top_k_candidates");
|
|
1323
|
+
} else {
|
|
1324
|
+
data->candidates = top_k;
|
|
1325
|
+
}
|
|
1326
|
+
|
|
1327
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1328
|
+
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
|
1329
|
+
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
|
|
1330
|
+
ggml_set_name(top_k_rows, "top_k_rows");
|
|
1331
|
+
|
|
1332
|
+
GGML_UNUSED(gf);
|
|
1333
|
+
}
|
|
1334
|
+
|
|
745
1335
|
static struct llama_sampler_i llama_sampler_top_k_i = {
|
|
746
|
-
/* .name
|
|
747
|
-
/* .accept
|
|
748
|
-
/* .apply
|
|
749
|
-
/* .reset
|
|
750
|
-
/* .clone
|
|
751
|
-
/* .free
|
|
1336
|
+
/* .name = */ llama_sampler_top_k_name,
|
|
1337
|
+
/* .accept = */ nullptr,
|
|
1338
|
+
/* .apply = */ llama_sampler_top_k_apply,
|
|
1339
|
+
/* .reset = */ nullptr,
|
|
1340
|
+
/* .clone = */ llama_sampler_top_k_clone,
|
|
1341
|
+
/* .free = */ llama_sampler_top_k_free,
|
|
1342
|
+
/* .backend_init = */ llama_sampler_top_k_backend_init,
|
|
1343
|
+
/* .backend_accept = */ nullptr,
|
|
1344
|
+
/* .backend_apply = */ llama_sampler_top_k_backend_apply,
|
|
1345
|
+
/* .backend_set_input = */ nullptr,
|
|
752
1346
|
};
|
|
753
1347
|
|
|
754
1348
|
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
|
1349
|
+
const bool is_empty = (k <= 0);
|
|
1350
|
+
|
|
1351
|
+
if (is_empty) {
|
|
1352
|
+
return llama_sampler_init_empty("?top-k");
|
|
1353
|
+
}
|
|
1354
|
+
|
|
755
1355
|
return llama_sampler_init(
|
|
756
1356
|
/* .iface = */ &llama_sampler_top_k_i,
|
|
757
1357
|
/* .ctx = */ new llama_sampler_top_k {
|
|
1358
|
+
("top-k"),
|
|
758
1359
|
/* .k = */ k,
|
|
759
1360
|
}
|
|
760
1361
|
);
|
|
@@ -762,15 +1363,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
|
|
762
1363
|
|
|
763
1364
|
// top-p
|
|
764
1365
|
|
|
765
|
-
struct llama_sampler_top_p {
|
|
1366
|
+
struct llama_sampler_top_p : public llama_sampler_backend {
|
|
766
1367
|
const float p;
|
|
767
1368
|
const size_t min_keep;
|
|
768
1369
|
|
|
769
1370
|
std::vector<llama_token_data> buf_sort;
|
|
770
1371
|
};
|
|
771
1372
|
|
|
772
|
-
static const char * llama_sampler_top_p_name(const struct llama_sampler *
|
|
773
|
-
|
|
1373
|
+
static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
|
|
1374
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1375
|
+
return sctx->get_name();
|
|
774
1376
|
}
|
|
775
1377
|
|
|
776
1378
|
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -837,19 +1439,118 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
|
|
837
1439
|
delete (llama_sampler_top_p *) smpl->ctx;
|
|
838
1440
|
}
|
|
839
1441
|
|
|
1442
|
+
static bool llama_sampler_top_p_backend_init(
|
|
1443
|
+
struct llama_sampler * smpl,
|
|
1444
|
+
ggml_backend_buffer_type_t buft) {
|
|
1445
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1446
|
+
|
|
1447
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1448
|
+
|
|
1449
|
+
sctx->init(res);
|
|
1450
|
+
|
|
1451
|
+
return res;
|
|
1452
|
+
}
|
|
1453
|
+
|
|
1454
|
+
static void llama_sampler_top_p_backend_apply(
|
|
1455
|
+
struct llama_sampler * smpl,
|
|
1456
|
+
struct ggml_context * ctx,
|
|
1457
|
+
struct ggml_cgraph * gf,
|
|
1458
|
+
struct llama_sampler_data * data) {
|
|
1459
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1460
|
+
|
|
1461
|
+
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
|
|
1462
|
+
GGML_ASSERT(ggml_nrows(a) == 1);
|
|
1463
|
+
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
|
|
1464
|
+
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
|
|
1465
|
+
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
|
|
1466
|
+
};
|
|
1467
|
+
|
|
1468
|
+
// Get the sorted logits in descending order.
|
|
1469
|
+
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
|
|
1470
|
+
ggml_set_name(sorted_idx, "top_p_sorted_idx");
|
|
1471
|
+
|
|
1472
|
+
// Do the sorting via reshape + get_rows
|
|
1473
|
+
struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
|
|
1474
|
+
ggml_set_name(sorted_logits, "top_p_sorted_logits");
|
|
1475
|
+
|
|
1476
|
+
struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
|
|
1477
|
+
ggml_set_name(softmax, "top_p_softmax");
|
|
1478
|
+
|
|
1479
|
+
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
|
|
1480
|
+
if (data->candidates) {
|
|
1481
|
+
data->candidates = ggml_sort(data->candidates, sorted_idx);
|
|
1482
|
+
} else {
|
|
1483
|
+
data->candidates = sorted_idx;
|
|
1484
|
+
}
|
|
1485
|
+
ggml_set_name(data->candidates, "top_p_candidates");
|
|
1486
|
+
|
|
1487
|
+
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
|
|
1488
|
+
struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
|
|
1489
|
+
ggml_set_name(cdf, "top_p_cdf");
|
|
1490
|
+
|
|
1491
|
+
// Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
|
|
1492
|
+
struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
|
|
1493
|
+
ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
|
|
1494
|
+
|
|
1495
|
+
struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
|
|
1496
|
+
ggml_set_name(mask, "top_p_mask");
|
|
1497
|
+
|
|
1498
|
+
// Taking the sum of the mask gives us the sum of elements after the threshold
|
|
1499
|
+
// we are interested in.
|
|
1500
|
+
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
|
1501
|
+
ggml_set_name(idxf, "top_p_index_f32");
|
|
1502
|
+
|
|
1503
|
+
// prevent out-of-bounds access
|
|
1504
|
+
idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
|
|
1505
|
+
|
|
1506
|
+
// construct ones tensor to set the value in the mask
|
|
1507
|
+
struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
|
|
1508
|
+
ggml_set_name(ones, "top_p_ones");
|
|
1509
|
+
|
|
1510
|
+
// Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
|
|
1511
|
+
struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
|
|
1512
|
+
|
|
1513
|
+
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
|
|
1514
|
+
mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
|
|
1515
|
+
|
|
1516
|
+
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
|
|
1517
|
+
// top_p_bias = (mask * 1e9f) - 1e9f.
|
|
1518
|
+
// So entries in the mask that we want to discard will become -1e9f, and
|
|
1519
|
+
// others will be 0 (meaning that will not effect the logits).
|
|
1520
|
+
const float large_val = 1e9f;
|
|
1521
|
+
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
|
|
1522
|
+
ggml_set_name(top_p_bias, "top_p_bias");
|
|
1523
|
+
|
|
1524
|
+
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
|
|
1525
|
+
ggml_set_name(data->logits, "top_p_logits");
|
|
1526
|
+
|
|
1527
|
+
GGML_UNUSED(gf);
|
|
1528
|
+
}
|
|
1529
|
+
|
|
840
1530
|
static struct llama_sampler_i llama_sampler_top_p_i = {
|
|
841
|
-
/* .name
|
|
842
|
-
/* .accept
|
|
843
|
-
/* .apply
|
|
844
|
-
/* .reset
|
|
845
|
-
/* .clone
|
|
846
|
-
/* .free
|
|
1531
|
+
/* .name = */ llama_sampler_top_p_name,
|
|
1532
|
+
/* .accept = */ nullptr,
|
|
1533
|
+
/* .apply = */ llama_sampler_top_p_apply,
|
|
1534
|
+
/* .reset = */ nullptr,
|
|
1535
|
+
/* .clone = */ llama_sampler_top_p_clone,
|
|
1536
|
+
/* .free = */ llama_sampler_top_p_free,
|
|
1537
|
+
/* .backend_init = */ llama_sampler_top_p_backend_init,
|
|
1538
|
+
/* .backend_accept = */ nullptr,
|
|
1539
|
+
/* .backend_apply = */ llama_sampler_top_p_backend_apply,
|
|
1540
|
+
/* .backend_set_input = */ nullptr,
|
|
847
1541
|
};
|
|
848
1542
|
|
|
849
1543
|
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
|
1544
|
+
const bool is_empty = p >= 1.0f;
|
|
1545
|
+
|
|
1546
|
+
if (is_empty) {
|
|
1547
|
+
return llama_sampler_init_empty("?top-p");
|
|
1548
|
+
}
|
|
1549
|
+
|
|
850
1550
|
return llama_sampler_init(
|
|
851
1551
|
/* .iface = */ &llama_sampler_top_p_i,
|
|
852
1552
|
/* .ctx = */ new llama_sampler_top_p {
|
|
1553
|
+
("top-p"),
|
|
853
1554
|
/* .p = */ p,
|
|
854
1555
|
/* .min_keep = */ min_keep,
|
|
855
1556
|
/* .buf_sort = */ {},
|
|
@@ -859,13 +1560,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
|
|
859
1560
|
|
|
860
1561
|
// min-p
|
|
861
1562
|
|
|
862
|
-
struct llama_sampler_min_p {
|
|
1563
|
+
struct llama_sampler_min_p : public llama_sampler_backend {
|
|
863
1564
|
const float p;
|
|
864
1565
|
const size_t min_keep;
|
|
865
1566
|
};
|
|
866
1567
|
|
|
867
|
-
static const char * llama_sampler_min_p_name(const struct llama_sampler *
|
|
868
|
-
|
|
1568
|
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
|
|
1569
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1570
|
+
return sctx->get_name();
|
|
869
1571
|
}
|
|
870
1572
|
|
|
871
1573
|
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -931,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
|
|
931
1633
|
delete (llama_sampler_min_p *) smpl->ctx;
|
|
932
1634
|
}
|
|
933
1635
|
|
|
1636
|
+
static bool llama_sampler_min_p_backend_init(
|
|
1637
|
+
struct llama_sampler * smpl,
|
|
1638
|
+
ggml_backend_buffer_type_t buft) {
|
|
1639
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1640
|
+
|
|
1641
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1642
|
+
|
|
1643
|
+
sctx->init(res);
|
|
1644
|
+
|
|
1645
|
+
return res;
|
|
1646
|
+
}
|
|
1647
|
+
|
|
1648
|
+
static void llama_sampler_min_p_backend_apply(
|
|
1649
|
+
struct llama_sampler * smpl,
|
|
1650
|
+
struct ggml_context * ctx,
|
|
1651
|
+
struct ggml_cgraph * gf,
|
|
1652
|
+
struct llama_sampler_data * data) {
|
|
1653
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1654
|
+
|
|
1655
|
+
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
|
1656
|
+
ggml_set_name(max_idx, "max_idx");
|
|
1657
|
+
|
|
1658
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1659
|
+
ggml_set_name(logits_rows, "logits_rows");
|
|
1660
|
+
|
|
1661
|
+
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
|
|
1662
|
+
ggml_set_name(max_logit, "max_logit");
|
|
1663
|
+
|
|
1664
|
+
// Calculate the threshold value.
|
|
1665
|
+
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
|
|
1666
|
+
ggml_set_name(threshold, "min_p_threshold");
|
|
1667
|
+
|
|
1668
|
+
// Subtract the threshold from logits.
|
|
1669
|
+
struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
|
|
1670
|
+
|
|
1671
|
+
// Create a mask where logits below the threshold are 0 (discard),
|
|
1672
|
+
// and others are 1 (keep).
|
|
1673
|
+
struct ggml_tensor * mask = ggml_step(ctx, sub);
|
|
1674
|
+
ggml_set_name(mask, "min_p_mask");
|
|
1675
|
+
|
|
1676
|
+
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
|
|
1677
|
+
// min_p_bias = (mask * 1e9f) - 1e9f.
|
|
1678
|
+
// So entries in the mask that we want to discard will become -1e9f, and
|
|
1679
|
+
// others will be 0 (meaning that will not effect the logits).
|
|
1680
|
+
const float large_val = 1e9f;
|
|
1681
|
+
struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
|
|
1682
|
+
ggml_set_name(min_p_bias, "min_p_bias");
|
|
1683
|
+
|
|
1684
|
+
// Add the min_p bias to the logits.
|
|
1685
|
+
data->logits = ggml_add(ctx, data->logits, min_p_bias);
|
|
1686
|
+
ggml_set_name(data->logits, "min_p_logits");
|
|
1687
|
+
|
|
1688
|
+
GGML_UNUSED(gf);
|
|
1689
|
+
}
|
|
1690
|
+
|
|
934
1691
|
static struct llama_sampler_i llama_sampler_min_p_i = {
|
|
935
|
-
/* .name
|
|
936
|
-
/* .accept
|
|
937
|
-
/* .apply
|
|
938
|
-
/* .reset
|
|
939
|
-
/* .clone
|
|
940
|
-
/* .free
|
|
1692
|
+
/* .name = */ llama_sampler_min_p_name,
|
|
1693
|
+
/* .accept = */ nullptr,
|
|
1694
|
+
/* .apply = */ llama_sampler_min_p_apply,
|
|
1695
|
+
/* .reset = */ nullptr,
|
|
1696
|
+
/* .clone = */ llama_sampler_min_p_clone,
|
|
1697
|
+
/* .free = */ llama_sampler_min_p_free,
|
|
1698
|
+
/* .backend_init = */ llama_sampler_min_p_backend_init,
|
|
1699
|
+
/* .backend_accept = */ nullptr,
|
|
1700
|
+
/* .backend_apply = */ llama_sampler_min_p_backend_apply,
|
|
1701
|
+
/* .backend_set_input = */ nullptr,
|
|
941
1702
|
};
|
|
942
1703
|
|
|
943
1704
|
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
|
1705
|
+
const bool is_empty = (p <= 0.0f);
|
|
1706
|
+
|
|
1707
|
+
if (is_empty) {
|
|
1708
|
+
return llama_sampler_init_empty("?min-p");
|
|
1709
|
+
}
|
|
1710
|
+
|
|
944
1711
|
return llama_sampler_init(
|
|
945
1712
|
/* .iface = */ &llama_sampler_min_p_i,
|
|
946
1713
|
/* .ctx = */ new llama_sampler_min_p {
|
|
1714
|
+
("min-p"),
|
|
947
1715
|
/* .p = */ p,
|
|
948
1716
|
/* .min_keep = */ min_keep,
|
|
949
1717
|
}
|
|
@@ -1031,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
|
|
1031
1799
|
}
|
|
1032
1800
|
|
|
1033
1801
|
static struct llama_sampler_i llama_sampler_typical_i = {
|
|
1034
|
-
/* .name
|
|
1035
|
-
/* .accept
|
|
1036
|
-
/* .apply
|
|
1037
|
-
/* .reset
|
|
1038
|
-
/* .clone
|
|
1039
|
-
/* .free
|
|
1802
|
+
/* .name = */ llama_sampler_typical_name,
|
|
1803
|
+
/* .accept = */ nullptr,
|
|
1804
|
+
/* .apply = */ llama_sampler_typical_apply,
|
|
1805
|
+
/* .reset = */ nullptr,
|
|
1806
|
+
/* .clone = */ llama_sampler_typical_clone,
|
|
1807
|
+
/* .free = */ llama_sampler_typical_free,
|
|
1808
|
+
/* .backend_init = */ nullptr,
|
|
1809
|
+
/* .backend_accept = */ nullptr,
|
|
1810
|
+
/* .backend_apply = */ nullptr,
|
|
1811
|
+
/* .backend_set_input = */ nullptr,
|
|
1040
1812
|
};
|
|
1041
1813
|
|
|
1042
1814
|
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
|
1815
|
+
const bool is_empty = (p >= 1.0f);
|
|
1816
|
+
|
|
1817
|
+
if (is_empty) {
|
|
1818
|
+
return llama_sampler_init_empty("?typical");
|
|
1819
|
+
}
|
|
1820
|
+
|
|
1043
1821
|
return llama_sampler_init(
|
|
1044
1822
|
/* .iface = */ &llama_sampler_typical_i,
|
|
1045
1823
|
/* .ctx = */ new llama_sampler_typical {
|
|
@@ -1051,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
|
|
1051
1829
|
|
|
1052
1830
|
// temp
|
|
1053
1831
|
|
|
1054
|
-
struct llama_sampler_temp {
|
|
1832
|
+
struct llama_sampler_temp : public llama_sampler_backend {
|
|
1055
1833
|
const float temp;
|
|
1056
1834
|
};
|
|
1057
1835
|
|
|
1058
|
-
static const char * llama_sampler_temp_name(const struct llama_sampler *
|
|
1059
|
-
|
|
1836
|
+
static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
|
|
1837
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1838
|
+
return sctx->get_name();
|
|
1060
1839
|
}
|
|
1061
1840
|
|
|
1062
1841
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -1074,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
|
|
1074
1853
|
delete (llama_sampler_temp *) smpl->ctx;
|
|
1075
1854
|
}
|
|
1076
1855
|
|
|
1856
|
+
static void llama_sampler_backend_temp_sampling(
|
|
1857
|
+
struct ggml_context * ctx,
|
|
1858
|
+
struct ggml_cgraph * gf,
|
|
1859
|
+
struct llama_sampler_data * data,
|
|
1860
|
+
float temp) {
|
|
1861
|
+
if (temp <= 0.0f) {
|
|
1862
|
+
// Find the most probable token index.
|
|
1863
|
+
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
|
1864
|
+
ggml_set_name(max_idx, "temp_max_idx");
|
|
1865
|
+
|
|
1866
|
+
if (data->candidates) {
|
|
1867
|
+
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
|
1868
|
+
data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
|
|
1869
|
+
} else {
|
|
1870
|
+
data->candidates = max_idx;
|
|
1871
|
+
}
|
|
1872
|
+
|
|
1873
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1874
|
+
data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
|
|
1875
|
+
|
|
1876
|
+
return;
|
|
1877
|
+
}
|
|
1878
|
+
|
|
1879
|
+
data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
|
|
1880
|
+
|
|
1881
|
+
GGML_UNUSED(gf);
|
|
1882
|
+
}
|
|
1883
|
+
|
|
1884
|
+
static bool llama_sampler_temp_backend_init(
|
|
1885
|
+
struct llama_sampler * smpl,
|
|
1886
|
+
ggml_backend_buffer_type_t buft) {
|
|
1887
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1888
|
+
|
|
1889
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1890
|
+
|
|
1891
|
+
sctx->init(res);
|
|
1892
|
+
|
|
1893
|
+
return res;
|
|
1894
|
+
}
|
|
1895
|
+
|
|
1896
|
+
static void llama_sampler_temp_backend_apply(
|
|
1897
|
+
struct llama_sampler * smpl,
|
|
1898
|
+
struct ggml_context * ctx,
|
|
1899
|
+
struct ggml_cgraph * gf,
|
|
1900
|
+
struct llama_sampler_data * data) {
|
|
1901
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1902
|
+
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
|
1903
|
+
}
|
|
1904
|
+
|
|
1077
1905
|
static struct llama_sampler_i llama_sampler_temp_i = {
|
|
1078
|
-
/* .name
|
|
1079
|
-
/* .accept
|
|
1080
|
-
/* .apply
|
|
1081
|
-
/* .reset
|
|
1082
|
-
/* .clone
|
|
1083
|
-
/* .free
|
|
1906
|
+
/* .name = */ llama_sampler_temp_name,
|
|
1907
|
+
/* .accept = */ nullptr,
|
|
1908
|
+
/* .apply = */ llama_sampler_temp_apply,
|
|
1909
|
+
/* .reset = */ nullptr,
|
|
1910
|
+
/* .clone = */ llama_sampler_temp_clone,
|
|
1911
|
+
/* .free = */ llama_sampler_temp_free,
|
|
1912
|
+
/* .backend_init = */ llama_sampler_temp_backend_init,
|
|
1913
|
+
/* .backend_accept = */ nullptr,
|
|
1914
|
+
/* .backend_apply = */ llama_sampler_temp_backend_apply,
|
|
1915
|
+
/* .backend_set_input = */ nullptr,
|
|
1084
1916
|
};
|
|
1085
1917
|
|
|
1086
1918
|
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
|
1919
|
+
const bool is_empty = temp == 1.0f;
|
|
1920
|
+
|
|
1921
|
+
if (is_empty) {
|
|
1922
|
+
return llama_sampler_init_empty("?temp");
|
|
1923
|
+
}
|
|
1924
|
+
|
|
1087
1925
|
return llama_sampler_init(
|
|
1088
1926
|
/* .iface = */ &llama_sampler_temp_i,
|
|
1089
1927
|
/* .ctx = */ new llama_sampler_temp {
|
|
1928
|
+
("temp"),
|
|
1090
1929
|
/*.temp = */ temp,
|
|
1091
1930
|
}
|
|
1092
1931
|
);
|
|
@@ -1094,14 +1933,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
|
|
|
1094
1933
|
|
|
1095
1934
|
// temp-ext
|
|
1096
1935
|
|
|
1097
|
-
struct llama_sampler_temp_ext {
|
|
1936
|
+
struct llama_sampler_temp_ext : public llama_sampler_backend {
|
|
1098
1937
|
const float temp;
|
|
1099
1938
|
const float delta;
|
|
1100
1939
|
const float exponent;
|
|
1101
1940
|
};
|
|
1102
1941
|
|
|
1103
|
-
static const char * llama_sampler_temp_ext_name(const struct llama_sampler *
|
|
1104
|
-
|
|
1942
|
+
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
|
|
1943
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
1944
|
+
return sctx->get_name();
|
|
1105
1945
|
}
|
|
1106
1946
|
|
|
1107
1947
|
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -1184,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
|
|
1184
2024
|
delete (llama_sampler_temp_ext *) smpl->ctx;
|
|
1185
2025
|
}
|
|
1186
2026
|
|
|
2027
|
+
static bool llama_sampler_temp_ext_backend_init(
|
|
2028
|
+
struct llama_sampler * smpl,
|
|
2029
|
+
ggml_backend_buffer_type_t buft) {
|
|
2030
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
2031
|
+
|
|
2032
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
2033
|
+
|
|
2034
|
+
sctx->init(res);
|
|
2035
|
+
|
|
2036
|
+
return res;
|
|
2037
|
+
}
|
|
2038
|
+
|
|
2039
|
+
static void llama_sampler_temp_ext_backend_apply(
|
|
2040
|
+
struct llama_sampler * smpl,
|
|
2041
|
+
struct ggml_context * ctx,
|
|
2042
|
+
struct ggml_cgraph * gf,
|
|
2043
|
+
struct llama_sampler_data * data) {
|
|
2044
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
2045
|
+
|
|
2046
|
+
// Revert to standard temperature scaling if delta or temp are non-positive.
|
|
2047
|
+
if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
|
|
2048
|
+
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
|
2049
|
+
return;
|
|
2050
|
+
}
|
|
2051
|
+
|
|
2052
|
+
// Calculate min_temp, max_temp, and max_entropy.
|
|
2053
|
+
const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
|
|
2054
|
+
const float max_temp = sctx->temp + sctx->delta;
|
|
2055
|
+
const float max_entropy = logf(data->logits->ne[0]);
|
|
2056
|
+
|
|
2057
|
+
// Calculate the probabilities.
|
|
2058
|
+
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
|
2059
|
+
ggml_set_name(probs, "temp_ext_softmax_probs");
|
|
2060
|
+
|
|
2061
|
+
// Clamp probabilities to avoid log(0) which would give -inf
|
|
2062
|
+
struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
|
|
2063
|
+
ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
|
|
2064
|
+
|
|
2065
|
+
// Calculate the entropy, entropy = -Σ(p * log(p)).
|
|
2066
|
+
struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
|
|
2067
|
+
struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
|
|
2068
|
+
struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
|
|
2069
|
+
struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
|
|
2070
|
+
ggml_set_name(log_probs, "temp_ext_log_probs");
|
|
2071
|
+
ggml_set_name(p_log_p, "temp_ext_p_log_p");
|
|
2072
|
+
ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
|
|
2073
|
+
ggml_set_name(entropy, "temp_ext_entropy");
|
|
2074
|
+
|
|
2075
|
+
// Normalize the entropy, norm_entropy = entropy / max_entropy
|
|
2076
|
+
struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
|
|
2077
|
+
ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
|
|
2078
|
+
|
|
2079
|
+
// Calculate the dynamic temperature:
|
|
2080
|
+
// dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
|
|
2081
|
+
//
|
|
2082
|
+
// Calculate powf(normalized_entropy, exponent) as
|
|
2083
|
+
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
|
|
2084
|
+
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
|
|
2085
|
+
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
|
|
2086
|
+
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
|
|
2087
|
+
// With pow_entropy computed we can now compute dyn_temp, scaling by
|
|
2088
|
+
// (max_temp - min_temp) and then adding min_temp.
|
|
2089
|
+
struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
|
|
2090
|
+
ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
|
|
2091
|
+
ggml_set_name(scaled_log, "temp_ext_scaled_log");
|
|
2092
|
+
ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
|
|
2093
|
+
ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
|
|
2094
|
+
|
|
2095
|
+
// Scale the logits by the dynamic temperature
|
|
2096
|
+
struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
|
|
2097
|
+
ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
|
|
2098
|
+
|
|
2099
|
+
data->logits = scaled_logits;
|
|
2100
|
+
}
|
|
2101
|
+
|
|
1187
2102
|
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
|
1188
|
-
/* .name
|
|
1189
|
-
/* .accept
|
|
1190
|
-
/* .apply
|
|
1191
|
-
/* .reset
|
|
1192
|
-
/* .clone
|
|
1193
|
-
/* .free
|
|
2103
|
+
/* .name = */ llama_sampler_temp_ext_name,
|
|
2104
|
+
/* .accept = */ nullptr,
|
|
2105
|
+
/* .apply = */ llama_sampler_temp_ext_apply,
|
|
2106
|
+
/* .reset = */ nullptr,
|
|
2107
|
+
/* .clone = */ llama_sampler_temp_ext_clone,
|
|
2108
|
+
/* .free = */ llama_sampler_temp_ext_free,
|
|
2109
|
+
/* .backend_init = */ llama_sampler_temp_ext_backend_init,
|
|
2110
|
+
/* .backend_accept = */ nullptr,
|
|
2111
|
+
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
|
|
2112
|
+
/* .backend_set_input = */ nullptr,
|
|
1194
2113
|
};
|
|
1195
2114
|
|
|
1196
2115
|
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
|
1197
|
-
|
|
2116
|
+
const bool is_empty = temp == 1.0f && delta <= 0.0f;
|
|
2117
|
+
|
|
2118
|
+
if (is_empty) {
|
|
2119
|
+
return llama_sampler_init_empty("?temp-ext");
|
|
2120
|
+
}
|
|
2121
|
+
|
|
2122
|
+
auto * res = llama_sampler_init(
|
|
1198
2123
|
/* .iface = */ &llama_sampler_temp_ext_i,
|
|
1199
2124
|
/* .ctx = */ new llama_sampler_temp_ext {
|
|
2125
|
+
("temp-ext"),
|
|
1200
2126
|
/* .temp = */ temp,
|
|
1201
2127
|
/* .delta = */ delta,
|
|
1202
2128
|
/* .exponent = */ exponent,
|
|
1203
2129
|
}
|
|
1204
2130
|
);
|
|
2131
|
+
|
|
2132
|
+
return res;
|
|
1205
2133
|
}
|
|
1206
2134
|
|
|
1207
2135
|
// xtc
|
|
@@ -1214,7 +2142,7 @@ struct llama_sampler_xtc {
|
|
|
1214
2142
|
const uint32_t seed;
|
|
1215
2143
|
uint32_t seed_cur;
|
|
1216
2144
|
|
|
1217
|
-
std::mt19937
|
|
2145
|
+
std::mt19937 rng;
|
|
1218
2146
|
};
|
|
1219
2147
|
|
|
1220
2148
|
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
|
@@ -1279,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
|
|
1279
2207
|
}
|
|
1280
2208
|
|
|
1281
2209
|
static struct llama_sampler_i llama_sampler_xtc_i = {
|
|
1282
|
-
/* .name
|
|
1283
|
-
/* .accept
|
|
1284
|
-
/* .apply
|
|
1285
|
-
/* .reset
|
|
1286
|
-
/* .clone
|
|
1287
|
-
/* .free
|
|
2210
|
+
/* .name = */ llama_sampler_xtc_name,
|
|
2211
|
+
/* .accept = */ nullptr,
|
|
2212
|
+
/* .apply = */ llama_sample_xtc_apply,
|
|
2213
|
+
/* .reset = */ llama_sampler_xtc_reset,
|
|
2214
|
+
/* .clone = */ llama_sampler_xtc_clone,
|
|
2215
|
+
/* .free = */ llama_sampler_xtc_free,
|
|
2216
|
+
/* .backend_init = */ nullptr,
|
|
2217
|
+
/* .backend_accept = */ nullptr,
|
|
2218
|
+
/* .backend_apply = */ nullptr,
|
|
2219
|
+
/* .backend_set_input = */ nullptr,
|
|
1288
2220
|
};
|
|
1289
2221
|
|
|
1290
2222
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
|
1291
|
-
|
|
2223
|
+
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
|
2224
|
+
|
|
2225
|
+
if (is_empty) {
|
|
2226
|
+
return llama_sampler_init_empty("?xtc");
|
|
2227
|
+
}
|
|
2228
|
+
|
|
2229
|
+
const auto seed_cur = get_rng_seed(seed);
|
|
2230
|
+
|
|
1292
2231
|
return llama_sampler_init(
|
|
1293
2232
|
/* .iface = */ &llama_sampler_xtc_i,
|
|
1294
2233
|
/* .ctx = */ new llama_sampler_xtc {
|
|
@@ -1387,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|
|
1387
2326
|
}
|
|
1388
2327
|
|
|
1389
2328
|
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|
1390
|
-
/* .name
|
|
1391
|
-
/* .accept
|
|
1392
|
-
/* .apply
|
|
1393
|
-
/* .reset
|
|
1394
|
-
/* .clone
|
|
1395
|
-
/* .free
|
|
2329
|
+
/* .name = */ llama_sampler_mirostat_name,
|
|
2330
|
+
/* .accept = */ nullptr,
|
|
2331
|
+
/* .apply = */ llama_sampler_mirostat_apply,
|
|
2332
|
+
/* .reset = */ llama_sampler_mirostat_reset,
|
|
2333
|
+
/* .clone = */ llama_sampler_mirostat_clone,
|
|
2334
|
+
/* .free = */ llama_sampler_mirostat_free,
|
|
2335
|
+
/* .backend_init = */ nullptr,
|
|
2336
|
+
/* .backend_accept = */ nullptr,
|
|
2337
|
+
/* .backend_apply = */ nullptr,
|
|
2338
|
+
/* .backend_set_input = */ nullptr,
|
|
1396
2339
|
};
|
|
1397
2340
|
|
|
1398
2341
|
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
|
1399
|
-
auto seed_cur = get_rng_seed(seed);
|
|
2342
|
+
const auto seed_cur = get_rng_seed(seed);
|
|
2343
|
+
|
|
1400
2344
|
return llama_sampler_init(
|
|
1401
2345
|
/* .iface = */ &llama_sampler_mirostat_i,
|
|
1402
2346
|
/* .ctx = */ new llama_sampler_mirostat {
|
|
@@ -1486,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
|
|
1486
2430
|
}
|
|
1487
2431
|
|
|
1488
2432
|
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|
1489
|
-
/* .name
|
|
1490
|
-
/* .accept
|
|
1491
|
-
/* .apply
|
|
1492
|
-
/* .reset
|
|
1493
|
-
/* .clone
|
|
1494
|
-
/* .free
|
|
2433
|
+
/* .name = */ llama_sampler_mirostat_v2_name,
|
|
2434
|
+
/* .accept = */ nullptr,
|
|
2435
|
+
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
|
2436
|
+
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
|
2437
|
+
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
|
2438
|
+
/* .free = */ llama_sampler_mirostat_v2_free,
|
|
2439
|
+
/* .backend_init = */ nullptr,
|
|
2440
|
+
/* .backend_accept = */ nullptr,
|
|
2441
|
+
/* .backend_apply = */ nullptr,
|
|
2442
|
+
/* .backend_set_input = */ nullptr,
|
|
1495
2443
|
};
|
|
1496
2444
|
|
|
1497
2445
|
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
|
@@ -1603,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
|
|
1603
2551
|
}
|
|
1604
2552
|
|
|
1605
2553
|
static struct llama_sampler_i llama_sampler_grammar_i = {
|
|
1606
|
-
/* .name
|
|
1607
|
-
/* .accept
|
|
1608
|
-
/* .apply
|
|
1609
|
-
/* .reset
|
|
1610
|
-
/* .clone
|
|
1611
|
-
/* .free
|
|
2554
|
+
/* .name = */ llama_sampler_grammar_name,
|
|
2555
|
+
/* .accept = */ llama_sampler_grammar_accept_impl,
|
|
2556
|
+
/* .apply = */ llama_sampler_grammar_apply,
|
|
2557
|
+
/* .reset = */ llama_sampler_grammar_reset,
|
|
2558
|
+
/* .clone = */ llama_sampler_grammar_clone,
|
|
2559
|
+
/* .free = */ llama_sampler_grammar_free,
|
|
2560
|
+
/* .backend_init = */ nullptr,
|
|
2561
|
+
/* .backend_accept = */ nullptr,
|
|
2562
|
+
/* .backend_apply = */ nullptr,
|
|
2563
|
+
/* .backend_set_input = */ nullptr,
|
|
1612
2564
|
};
|
|
1613
2565
|
|
|
1614
2566
|
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
@@ -1625,10 +2577,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
|
1625
2577
|
auto * ctx = new llama_sampler_grammar;
|
|
1626
2578
|
|
|
1627
2579
|
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
|
2580
|
+
std::string trigger_pattern;
|
|
2581
|
+
llama_grammar * grammar = nullptr;
|
|
1628
2582
|
// TODO: remove trigger_words support.
|
|
1629
2583
|
if (trigger_words != nullptr && num_trigger_words > 0) {
|
|
1630
2584
|
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
|
1631
|
-
|
|
2585
|
+
trigger_pattern = "[\\s\\S]*?(";
|
|
1632
2586
|
for (size_t i = 0; i < num_trigger_words; ++i) {
|
|
1633
2587
|
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
|
1634
2588
|
if (i > 0) {
|
|
@@ -1637,15 +2591,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
|
1637
2591
|
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
|
1638
2592
|
}
|
|
1639
2593
|
trigger_pattern += ")[\\s\\S]*";
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
2594
|
+
|
|
2595
|
+
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
|
|
2596
|
+
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
|
|
2597
|
+
} else {
|
|
2598
|
+
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
|
|
1643
2599
|
}
|
|
1644
2600
|
*ctx = {
|
|
1645
2601
|
/* .vocab = */ vocab,
|
|
1646
2602
|
/* .grammar_str = */ grammar_str,
|
|
1647
2603
|
/* .grammar_root = */ grammar_root,
|
|
1648
|
-
/* .grammar = */
|
|
2604
|
+
/* .grammar = */ grammar,
|
|
1649
2605
|
};
|
|
1650
2606
|
if (!ctx->grammar) {
|
|
1651
2607
|
delete ctx;
|
|
@@ -1806,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
|
|
1806
2762
|
}
|
|
1807
2763
|
|
|
1808
2764
|
static struct llama_sampler_i llama_sampler_penalties_i = {
|
|
1809
|
-
/* .name
|
|
1810
|
-
/* .accept
|
|
1811
|
-
/* .apply
|
|
1812
|
-
/* .reset
|
|
1813
|
-
/* .clone
|
|
1814
|
-
/* .free
|
|
2765
|
+
/* .name = */ llama_sampler_penalties_name,
|
|
2766
|
+
/* .accept = */ llama_sampler_penalties_accept,
|
|
2767
|
+
/* .apply = */ llama_sampler_penalties_apply,
|
|
2768
|
+
/* .reset = */ llama_sampler_penalties_reset,
|
|
2769
|
+
/* .clone = */ llama_sampler_penalties_clone,
|
|
2770
|
+
/* .free = */ llama_sampler_penalties_free,
|
|
2771
|
+
/* .backend_init = */ nullptr,
|
|
2772
|
+
/* .backend_accept = */ nullptr,
|
|
2773
|
+
/* .backend_apply = */ nullptr,
|
|
2774
|
+
/* .backend_set_input = */ nullptr,
|
|
1815
2775
|
};
|
|
1816
2776
|
|
|
1817
2777
|
struct llama_sampler * llama_sampler_init_penalties(
|
|
@@ -1821,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
|
1821
2781
|
float penalty_present) {
|
|
1822
2782
|
penalty_last_n = std::max(penalty_last_n, 0);
|
|
1823
2783
|
|
|
2784
|
+
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
|
2785
|
+
|
|
2786
|
+
if (is_empty) {
|
|
2787
|
+
return llama_sampler_init_empty("?penalties");
|
|
2788
|
+
}
|
|
2789
|
+
|
|
1824
2790
|
return llama_sampler_init(
|
|
1825
2791
|
/* .iface = */ &llama_sampler_penalties_i,
|
|
1826
2792
|
/* .ctx = */ new llama_sampler_penalties {
|
|
@@ -1858,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
|
|
1858
2824
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1859
2825
|
// Only count non-negative infinity values
|
|
1860
2826
|
if (cur_p->data[i].logit != -INFINITY) {
|
|
1861
|
-
|
|
1862
|
-
max = cur_p->data[i].logit;
|
|
1863
|
-
}
|
|
2827
|
+
max = std::max(max, cur_p->data[i].logit);
|
|
1864
2828
|
logits_sum += cur_p->data[i].logit;
|
|
1865
2829
|
valid_count++;
|
|
1866
2830
|
}
|
|
@@ -1897,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
|
|
1897
2861
|
}
|
|
1898
2862
|
|
|
1899
2863
|
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
|
1900
|
-
/* .name
|
|
1901
|
-
/* .accept
|
|
1902
|
-
/* .apply
|
|
1903
|
-
/* .reset
|
|
1904
|
-
/* .clone
|
|
1905
|
-
/* .free
|
|
2864
|
+
/* .name = */ llama_sampler_top_n_sigma_name,
|
|
2865
|
+
/* .accept = */ nullptr,
|
|
2866
|
+
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
|
2867
|
+
/* .reset = */ nullptr,
|
|
2868
|
+
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
|
2869
|
+
/* .free = */ llama_sampler_top_n_sigma_free,
|
|
2870
|
+
/* .backend_init = */ nullptr,
|
|
2871
|
+
/* .backend_accept = */ nullptr,
|
|
2872
|
+
/* .backend_apply = */ nullptr,
|
|
2873
|
+
/* .backend_set_input = */ nullptr,
|
|
1906
2874
|
};
|
|
1907
2875
|
|
|
1908
2876
|
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
|
2877
|
+
const bool is_empty = (n <= 0.0f);
|
|
2878
|
+
|
|
2879
|
+
if (is_empty) {
|
|
2880
|
+
return llama_sampler_init_empty("?top-n-sigma");
|
|
2881
|
+
}
|
|
2882
|
+
|
|
1909
2883
|
return llama_sampler_init(
|
|
1910
2884
|
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
|
1911
2885
|
/* .ctx = */ new llama_sampler_top_n_sigma {
|
|
@@ -2227,12 +3201,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
|
|
2227
3201
|
}
|
|
2228
3202
|
|
|
2229
3203
|
static struct llama_sampler_i llama_sampler_dry_i = {
|
|
2230
|
-
/* .name
|
|
2231
|
-
/* .accept
|
|
2232
|
-
/* .apply
|
|
2233
|
-
/* .reset
|
|
2234
|
-
/* .clone
|
|
2235
|
-
/* .free
|
|
3204
|
+
/* .name = */ llama_sampler_dry_name,
|
|
3205
|
+
/* .accept = */ llama_sampler_dry_accept,
|
|
3206
|
+
/* .apply = */ llama_sampler_dry_apply,
|
|
3207
|
+
/* .reset = */ llama_sampler_dry_reset,
|
|
3208
|
+
/* .clone = */ llama_sampler_dry_clone,
|
|
3209
|
+
/* .free = */ llama_sampler_dry_free,
|
|
3210
|
+
/* .backend_init = */ nullptr,
|
|
3211
|
+
/* .backend_accept = */ nullptr,
|
|
3212
|
+
/* .backend_apply = */ nullptr,
|
|
3213
|
+
/* .backend_set_input = */ nullptr,
|
|
2236
3214
|
};
|
|
2237
3215
|
|
|
2238
3216
|
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
|
@@ -2243,6 +3221,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
|
|
2243
3221
|
|
|
2244
3222
|
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
|
2245
3223
|
|
|
3224
|
+
if (!dry_enabled) {
|
|
3225
|
+
return llama_sampler_init_empty("?dry");
|
|
3226
|
+
}
|
|
3227
|
+
|
|
2246
3228
|
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
|
2247
3229
|
// Process sequence breakers
|
|
2248
3230
|
for (size_t i = 0; i < num_breakers; ++i) {
|
|
@@ -2313,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
|
|
2313
3295
|
|
|
2314
3296
|
// logit-bias
|
|
2315
3297
|
|
|
2316
|
-
struct llama_sampler_logit_bias {
|
|
3298
|
+
struct llama_sampler_logit_bias : public llama_sampler_backend {
|
|
2317
3299
|
const int32_t n_vocab;
|
|
2318
3300
|
|
|
2319
3301
|
const std::vector<llama_logit_bias> logit_bias;
|
|
2320
3302
|
|
|
2321
3303
|
std::vector<llama_logit_bias> to_search;
|
|
3304
|
+
|
|
3305
|
+
struct ggml_tensor * inp_logit_bias;
|
|
3306
|
+
struct ggml_tensor * inp_logit_idxs;
|
|
3307
|
+
|
|
3308
|
+
ggml_context_ptr inp_ctx;
|
|
3309
|
+
ggml_backend_buffer_ptr inp_buf;
|
|
2322
3310
|
};
|
|
2323
3311
|
|
|
2324
|
-
static const char * llama_sampler_logit_bias_name(const struct llama_sampler *
|
|
2325
|
-
|
|
3312
|
+
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
|
|
3313
|
+
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3314
|
+
return ctx->get_name();
|
|
2326
3315
|
}
|
|
2327
3316
|
|
|
2328
3317
|
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -2367,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
|
|
2367
3356
|
delete (llama_sampler_logit_bias *) smpl->ctx;
|
|
2368
3357
|
}
|
|
2369
3358
|
|
|
3359
|
+
static void llama_sampler_logit_bias_backend_apply(
|
|
3360
|
+
struct llama_sampler * smpl,
|
|
3361
|
+
struct ggml_context * ctx,
|
|
3362
|
+
struct ggml_cgraph * gf,
|
|
3363
|
+
struct llama_sampler_data * data) {
|
|
3364
|
+
GGML_UNUSED(gf);
|
|
3365
|
+
GGML_UNUSED(ctx);
|
|
3366
|
+
|
|
3367
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3368
|
+
if (sctx->logit_bias.empty()) {
|
|
3369
|
+
return;
|
|
3370
|
+
}
|
|
3371
|
+
|
|
3372
|
+
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
|
|
3373
|
+
|
|
3374
|
+
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
|
|
3375
|
+
cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
|
|
3376
|
+
cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
|
|
3377
|
+
|
|
3378
|
+
data->logits = ggml_add(ctx, data->logits, cur);
|
|
3379
|
+
}
|
|
3380
|
+
|
|
3381
|
+
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
|
|
3382
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3383
|
+
if (sctx->logit_bias.empty()) {
|
|
3384
|
+
return;
|
|
3385
|
+
}
|
|
3386
|
+
|
|
3387
|
+
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
|
3388
|
+
GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
|
|
3389
|
+
|
|
3390
|
+
const size_t n = sctx->logit_bias.size();
|
|
3391
|
+
|
|
3392
|
+
std::vector<float> data_logit_bias(n, 0.0f);
|
|
3393
|
+
std::vector<int32_t> data_logit_idxs(n, 0);
|
|
3394
|
+
for (size_t i = 0; i < n; ++i) {
|
|
3395
|
+
const auto & lb = sctx->logit_bias[i];
|
|
3396
|
+
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
|
3397
|
+
data_logit_bias[i] = lb.bias;
|
|
3398
|
+
data_logit_idxs[i] = lb.token;
|
|
3399
|
+
}
|
|
3400
|
+
|
|
3401
|
+
ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
|
3402
|
+
ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
|
|
3403
|
+
}
|
|
3404
|
+
|
|
3405
|
+
static bool llama_sampler_logit_bias_backend_init(
|
|
3406
|
+
struct llama_sampler * smpl,
|
|
3407
|
+
ggml_backend_buffer_type_t buft) {
|
|
3408
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3409
|
+
|
|
3410
|
+
sctx->init(true);
|
|
3411
|
+
|
|
3412
|
+
if (sctx->logit_bias.empty()) {
|
|
3413
|
+
return true;
|
|
3414
|
+
}
|
|
3415
|
+
|
|
3416
|
+
ggml_init_params params = {
|
|
3417
|
+
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
|
3418
|
+
/*.mem_buffer =*/ nullptr,
|
|
3419
|
+
/*.no_alloc =*/ true,
|
|
3420
|
+
};
|
|
3421
|
+
|
|
3422
|
+
sctx->inp_ctx.reset(ggml_init(params));
|
|
3423
|
+
|
|
3424
|
+
const size_t n = sctx->logit_bias.size();
|
|
3425
|
+
|
|
3426
|
+
sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
|
|
3427
|
+
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
|
|
3428
|
+
ggml_set_input(sctx->inp_logit_bias);
|
|
3429
|
+
|
|
3430
|
+
sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
|
|
3431
|
+
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
|
|
3432
|
+
ggml_set_input(sctx->inp_logit_idxs);
|
|
3433
|
+
|
|
3434
|
+
// Allocate all tensors from our context to the backend
|
|
3435
|
+
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
|
3436
|
+
|
|
3437
|
+
ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
|
|
3438
|
+
|
|
3439
|
+
return true;
|
|
3440
|
+
}
|
|
3441
|
+
|
|
2370
3442
|
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
|
2371
|
-
/* .name
|
|
2372
|
-
/* .accept
|
|
2373
|
-
/* .apply
|
|
2374
|
-
/* .reset
|
|
2375
|
-
/* .clone
|
|
2376
|
-
/* .free
|
|
3443
|
+
/* .name = */ llama_sampler_logit_bias_name,
|
|
3444
|
+
/* .accept = */ nullptr,
|
|
3445
|
+
/* .apply = */ llama_sampler_logit_bias_apply,
|
|
3446
|
+
/* .reset = */ nullptr,
|
|
3447
|
+
/* .clone = */ llama_sampler_logit_bias_clone,
|
|
3448
|
+
/* .free = */ llama_sampler_logit_bias_free,
|
|
3449
|
+
/* .backend_init = */ llama_sampler_logit_bias_backend_init,
|
|
3450
|
+
/* .backend_accept = */ nullptr,
|
|
3451
|
+
/* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
|
|
3452
|
+
/* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
|
|
2377
3453
|
};
|
|
2378
3454
|
|
|
2379
3455
|
struct llama_sampler * llama_sampler_init_logit_bias(
|
|
2380
3456
|
int32_t n_vocab,
|
|
2381
3457
|
int32_t n_logit_bias,
|
|
2382
3458
|
const llama_logit_bias * logit_bias) {
|
|
3459
|
+
const bool is_empty = n_logit_bias <= 0;
|
|
3460
|
+
|
|
3461
|
+
if (is_empty) {
|
|
3462
|
+
return llama_sampler_init_empty("?logit-bias");
|
|
3463
|
+
}
|
|
3464
|
+
|
|
2383
3465
|
return llama_sampler_init(
|
|
2384
3466
|
/* .iface = */ &llama_sampler_logit_bias_i,
|
|
2385
3467
|
/* .ctx = */ new llama_sampler_logit_bias {
|
|
2386
|
-
|
|
2387
|
-
/* .
|
|
2388
|
-
/* .
|
|
3468
|
+
("logit-bias"),
|
|
3469
|
+
/* .n_vocab = */ n_vocab,
|
|
3470
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
|
3471
|
+
/* .to_search = */ {},
|
|
3472
|
+
/* .inp_logit_bias = */ nullptr,
|
|
3473
|
+
/* .inp_logit_idxs = */ nullptr,
|
|
3474
|
+
/* .inp_ctx = */ nullptr,
|
|
3475
|
+
/* .inp_buf = */ nullptr,
|
|
2389
3476
|
}
|
|
2390
3477
|
);
|
|
2391
3478
|
}
|
|
@@ -2541,8 +3628,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
|
|
2541
3628
|
if (n_non_eog == 0) {
|
|
2542
3629
|
cur_p->size = 1;
|
|
2543
3630
|
cur_p->data[0].id = ctx->vocab->token_eot();
|
|
3631
|
+
if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
|
|
3632
|
+
cur_p->data[0].id = ctx->vocab->token_eos();
|
|
3633
|
+
}
|
|
2544
3634
|
cur_p->data[0].logit = 1.0f;
|
|
2545
3635
|
|
|
3636
|
+
GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
|
|
3637
|
+
|
|
2546
3638
|
return;
|
|
2547
3639
|
}
|
|
2548
3640
|
|
|
@@ -2593,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
|
|
2593
3685
|
}
|
|
2594
3686
|
|
|
2595
3687
|
static struct llama_sampler_i llama_sampler_infill_i = {
|
|
2596
|
-
/* .name
|
|
2597
|
-
/* .accept
|
|
2598
|
-
/* .apply
|
|
2599
|
-
/* .reset
|
|
2600
|
-
/* .clone
|
|
2601
|
-
/* .free
|
|
3688
|
+
/* .name = */ llama_sampler_infill_name,
|
|
3689
|
+
/* .accept = */ nullptr,
|
|
3690
|
+
/* .apply = */ llama_sampler_infill_apply,
|
|
3691
|
+
/* .reset = */ nullptr,
|
|
3692
|
+
/* .clone = */ llama_sampler_infill_clone,
|
|
3693
|
+
/* .free = */ llama_sampler_infill_free,
|
|
3694
|
+
/* .backend_apply = */ nullptr,
|
|
3695
|
+
/* .backend_accept = */ nullptr,
|
|
3696
|
+
/* .backend_set_input = */ nullptr,
|
|
3697
|
+
/* .backend_init = */ nullptr,
|
|
2602
3698
|
};
|
|
2603
3699
|
|
|
2604
3700
|
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
|
@@ -2630,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
|
2630
3726
|
if (smpl->iface == &llama_sampler_chain_i) {
|
|
2631
3727
|
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
|
2632
3728
|
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
|
2633
|
-
const uint32_t seed = llama_sampler_get_seed(
|
|
3729
|
+
const uint32_t seed = llama_sampler_get_seed(it->ptr);
|
|
2634
3730
|
if (seed != LLAMA_DEFAULT_SEED) {
|
|
2635
3731
|
return seed;
|
|
2636
3732
|
}
|
|
@@ -2660,8 +3756,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
|
|
|
2660
3756
|
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
|
2661
3757
|
const auto data = llama_perf_sampler(chain);
|
|
2662
3758
|
|
|
2663
|
-
LLAMA_LOG_INFO("%s:
|
|
2664
|
-
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
|
3759
|
+
LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
|
|
2665
3760
|
}
|
|
2666
3761
|
|
|
2667
3762
|
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
@@ -2671,5 +3766,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
|
2671
3766
|
|
|
2672
3767
|
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
|
2673
3768
|
|
|
2674
|
-
ctx->t_sample_us =
|
|
3769
|
+
ctx->t_sample_us = 0;
|
|
3770
|
+
ctx->n_sample = 0;
|
|
2675
3771
|
}
|