whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only, and it reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,9 +1,12 @@
|
|
|
1
|
-
#include "llama-
|
|
1
|
+
#include "llama-sampler.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
4
|
#include "llama-vocab.h"
|
|
5
5
|
#include "llama-grammar.h"
|
|
6
6
|
|
|
7
|
+
#include "ggml-cpp.h"
|
|
8
|
+
|
|
9
|
+
#include <array>
|
|
7
10
|
#include <algorithm>
|
|
8
11
|
#include <cassert>
|
|
9
12
|
#include <cfloat>
|
|
@@ -345,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
|
|
|
345
348
|
|
|
346
349
|
// llama_sampler API
|
|
347
350
|
|
|
348
|
-
struct llama_sampler * llama_sampler_init(
|
|
351
|
+
struct llama_sampler * llama_sampler_init(
|
|
352
|
+
struct llama_sampler_i * iface,
|
|
353
|
+
llama_sampler_context_t ctx) {
|
|
349
354
|
return new llama_sampler {
|
|
350
355
|
/* .iface = */ iface,
|
|
351
356
|
/* .ctx = */ ctx,
|
|
@@ -361,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
|
|
361
366
|
}
|
|
362
367
|
|
|
363
368
|
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
|
369
|
+
if (!smpl) {
|
|
370
|
+
return;
|
|
371
|
+
}
|
|
372
|
+
|
|
364
373
|
if (smpl->iface->accept) {
|
|
365
374
|
smpl->iface->accept(smpl, token);
|
|
366
375
|
}
|
|
367
376
|
}
|
|
368
377
|
|
|
369
378
|
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
|
379
|
+
if (!smpl) {
|
|
380
|
+
return;
|
|
381
|
+
}
|
|
382
|
+
|
|
370
383
|
GGML_ASSERT(smpl->iface->apply);
|
|
371
384
|
smpl->iface->apply(smpl, cur_p);
|
|
372
385
|
}
|
|
373
386
|
|
|
374
387
|
void llama_sampler_reset(struct llama_sampler * smpl) {
|
|
388
|
+
if (!smpl) {
|
|
389
|
+
return;
|
|
390
|
+
}
|
|
391
|
+
|
|
375
392
|
if (smpl->iface->reset) {
|
|
376
393
|
smpl->iface->reset(smpl);
|
|
377
394
|
}
|
|
378
395
|
}
|
|
379
396
|
|
|
380
397
|
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
|
398
|
+
if (!smpl) {
|
|
399
|
+
return nullptr;
|
|
400
|
+
}
|
|
401
|
+
|
|
381
402
|
if (smpl->iface->clone) {
|
|
382
403
|
return smpl->iface->clone(smpl);
|
|
383
404
|
}
|
|
@@ -404,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
|
|
|
404
425
|
delete smpl;
|
|
405
426
|
}
|
|
406
427
|
|
|
407
|
-
|
|
408
|
-
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
428
|
+
// empty sampler
|
|
409
429
|
|
|
410
|
-
|
|
411
|
-
const
|
|
430
|
+
struct llama_sampler_empty {
|
|
431
|
+
const char * name;
|
|
432
|
+
};
|
|
412
433
|
|
|
413
|
-
|
|
434
|
+
static struct llama_sampler * llama_sampler_init_empty(const char * name);
|
|
435
|
+
|
|
436
|
+
static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
|
|
437
|
+
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
|
438
|
+
return ctx->name;
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
|
|
442
|
+
GGML_UNUSED(smpl);
|
|
443
|
+
GGML_UNUSED(token);
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
447
|
+
GGML_UNUSED(smpl);
|
|
448
|
+
GGML_UNUSED(cur_p);
|
|
449
|
+
}
|
|
450
|
+
|
|
451
|
+
static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
|
|
452
|
+
GGML_UNUSED(smpl);
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
|
|
456
|
+
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
|
457
|
+
return llama_sampler_init_empty(ctx->name);
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
static void llama_sampler_empty_free(struct llama_sampler * smpl) {
|
|
461
|
+
delete (llama_sampler_empty *) smpl->ctx;
|
|
462
|
+
}
|
|
414
463
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
464
|
+
static bool llama_sampler_empty_backend_init(
|
|
465
|
+
struct llama_sampler * smpl,
|
|
466
|
+
ggml_backend_buffer_type_t buft) {
|
|
467
|
+
GGML_UNUSED(smpl);
|
|
468
|
+
GGML_UNUSED(buft);
|
|
469
|
+
|
|
470
|
+
return true;
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
static void llama_sampler_empty_backend_accept(
|
|
474
|
+
struct llama_sampler * smpl,
|
|
475
|
+
ggml_context * ctx,
|
|
476
|
+
ggml_cgraph * gf,
|
|
477
|
+
struct ggml_tensor * selected_token) {
|
|
478
|
+
GGML_UNUSED(smpl);
|
|
479
|
+
GGML_UNUSED(ctx);
|
|
480
|
+
GGML_UNUSED(gf);
|
|
481
|
+
GGML_UNUSED(selected_token);
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
static void llama_sampler_empty_backend_apply(
|
|
485
|
+
struct llama_sampler * smpl,
|
|
486
|
+
struct ggml_context * ctx,
|
|
487
|
+
struct ggml_cgraph * gf,
|
|
488
|
+
struct llama_sampler_data * data) {
|
|
489
|
+
GGML_UNUSED(smpl);
|
|
490
|
+
GGML_UNUSED(ctx);
|
|
491
|
+
GGML_UNUSED(gf);
|
|
492
|
+
GGML_UNUSED(data);
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
|
|
496
|
+
GGML_UNUSED(smpl);
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
static struct llama_sampler_i llama_sampler_empty_i = {
|
|
500
|
+
/* .name = */ llama_sampler_empty_name,
|
|
501
|
+
/* .accept = */ llama_sampler_empty_accept,
|
|
502
|
+
/* .apply = */ llama_sampler_empty_apply,
|
|
503
|
+
/* .reset = */ llama_sampler_empty_reset,
|
|
504
|
+
/* .clone = */ llama_sampler_empty_clone,
|
|
505
|
+
/* .free = */ llama_sampler_empty_free,
|
|
506
|
+
/* .backend_init = */ llama_sampler_empty_backend_init,
|
|
507
|
+
/* .backend_accept = */ llama_sampler_empty_backend_accept,
|
|
508
|
+
/* .backend_apply = */ llama_sampler_empty_backend_apply,
|
|
509
|
+
/* .backend_set_input = */ llama_sampler_empty_backend_set_input,
|
|
510
|
+
};
|
|
511
|
+
|
|
512
|
+
struct llama_sampler * llama_sampler_init_empty(const char * name) {
|
|
513
|
+
return llama_sampler_init(
|
|
514
|
+
/* .iface = */ &llama_sampler_empty_i,
|
|
515
|
+
/* .ctx = */ new llama_sampler_empty {
|
|
516
|
+
/* .name = */ name,
|
|
517
|
+
}
|
|
518
|
+
);
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
// common backend sampler functionality
|
|
522
|
+
//
|
|
523
|
+
// +name : means that the sampler is support and will run on the backend
|
|
524
|
+
// -name : means that a ggml operator is not supported by the backend
|
|
525
|
+
//
|
|
526
|
+
struct llama_sampler_backend {
|
|
527
|
+
llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
|
|
528
|
+
|
|
529
|
+
const char * get_name() {
|
|
530
|
+
if (!is_init) {
|
|
531
|
+
return name.c_str();
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
if (support) {
|
|
535
|
+
name_ext = "+" + name;
|
|
536
|
+
} else {
|
|
537
|
+
name_ext = "-" + name;
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
return name_ext.c_str();
|
|
420
541
|
}
|
|
421
542
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
543
|
+
void init(bool support) {
|
|
544
|
+
GGML_ASSERT(this->is_init == false);
|
|
545
|
+
|
|
546
|
+
this->is_init = true;
|
|
547
|
+
this->support = support;
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
private:
|
|
551
|
+
std::string name;
|
|
552
|
+
std::string name_ext;
|
|
553
|
+
|
|
554
|
+
bool is_init;
|
|
555
|
+
bool support;
|
|
556
|
+
};
|
|
557
|
+
|
|
558
|
+
// check if all ggml ops used by the sampler are supported by the backend
|
|
559
|
+
static bool llama_sampler_backend_support(
|
|
560
|
+
llama_sampler * smpl,
|
|
561
|
+
ggml_backend_buffer_type_t buft) {
|
|
562
|
+
auto * device = ggml_backend_buft_get_device(buft);
|
|
563
|
+
if (!device) {
|
|
564
|
+
// CPU backend always supported
|
|
565
|
+
return true;
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
ggml_init_params params = {
|
|
569
|
+
/*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
|
|
570
|
+
/*.mem_buffer =*/ NULL,
|
|
571
|
+
/*.no_alloc =*/ true,
|
|
427
572
|
};
|
|
428
573
|
|
|
429
|
-
|
|
574
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
575
|
+
if (!ctx_ptr) {
|
|
576
|
+
throw std::runtime_error(format("failed to create ggml context"));
|
|
577
|
+
}
|
|
430
578
|
|
|
431
|
-
|
|
579
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
432
580
|
|
|
433
|
-
|
|
581
|
+
const int64_t n = 1024*1024;
|
|
434
582
|
|
|
435
|
-
|
|
583
|
+
llama_sampler_data data = {
|
|
584
|
+
/*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
|
|
585
|
+
/*.probs = */ nullptr,
|
|
586
|
+
/*.sampled = */ nullptr,
|
|
587
|
+
/*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
|
|
588
|
+
};
|
|
436
589
|
|
|
437
|
-
|
|
590
|
+
ggml_cgraph * gf = ggml_new_graph(ctx);
|
|
591
|
+
|
|
592
|
+
smpl->iface->backend_apply(smpl, ctx, gf, &data);
|
|
593
|
+
|
|
594
|
+
if (data.logits) {
|
|
595
|
+
ggml_build_forward_expand(gf, data.logits);
|
|
596
|
+
}
|
|
597
|
+
|
|
598
|
+
if (data.probs) {
|
|
599
|
+
ggml_build_forward_expand(gf, data.probs);
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
if (data.sampled) {
|
|
603
|
+
ggml_build_forward_expand(gf, data.sampled);
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
if (data.candidates) {
|
|
607
|
+
ggml_build_forward_expand(gf, data.candidates);
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
611
|
+
struct ggml_tensor * op = ggml_graph_node(gf, i);
|
|
612
|
+
|
|
613
|
+
if (!ggml_backend_dev_supports_op(device, op)) {
|
|
614
|
+
LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
|
|
615
|
+
__func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
|
|
616
|
+
|
|
617
|
+
return false;
|
|
618
|
+
}
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
return true;
|
|
438
622
|
}
|
|
439
623
|
|
|
440
624
|
// sampler chain
|
|
@@ -448,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
|
|
|
448
632
|
|
|
449
633
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
450
634
|
|
|
451
|
-
for (auto
|
|
452
|
-
llama_sampler_accept(smpl, token);
|
|
635
|
+
for (auto & smpl : chain->samplers) {
|
|
636
|
+
llama_sampler_accept(smpl.ptr, token);
|
|
453
637
|
}
|
|
454
638
|
|
|
455
639
|
chain->n_sample++;
|
|
@@ -460,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
|
|
|
460
644
|
|
|
461
645
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
|
462
646
|
|
|
463
|
-
|
|
464
|
-
|
|
647
|
+
bool is_backend = chain->is_init;
|
|
648
|
+
|
|
649
|
+
for (auto & smpl : chain->samplers) {
|
|
650
|
+
if (is_backend && smpl.is_backend) {
|
|
651
|
+
continue;
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
is_backend = false;
|
|
655
|
+
|
|
656
|
+
if (smpl.ptr->iface->apply == nullptr) {
|
|
657
|
+
continue;
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
llama_sampler_apply(smpl.ptr, cur_p);
|
|
465
661
|
}
|
|
466
662
|
}
|
|
467
663
|
|
|
468
664
|
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
|
469
665
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
470
666
|
|
|
471
|
-
for (auto
|
|
472
|
-
llama_sampler_reset(smpl);
|
|
667
|
+
for (auto & smpl : chain->samplers) {
|
|
668
|
+
llama_sampler_reset(smpl.ptr);
|
|
473
669
|
}
|
|
474
|
-
|
|
475
|
-
chain->t_sample_us = 0;
|
|
476
|
-
chain->n_sample = 0;
|
|
477
670
|
}
|
|
478
671
|
|
|
479
672
|
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
|
@@ -481,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
|
|
481
674
|
|
|
482
675
|
auto * result = llama_sampler_chain_init(chain_src->params);
|
|
483
676
|
|
|
484
|
-
for (auto
|
|
485
|
-
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
|
677
|
+
for (const auto & smpl : chain_src->samplers) {
|
|
678
|
+
llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
|
|
486
679
|
}
|
|
487
680
|
|
|
488
681
|
return result;
|
|
@@ -491,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
|
|
491
684
|
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
|
492
685
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
493
686
|
|
|
494
|
-
for (auto
|
|
495
|
-
llama_sampler_free(smpl);
|
|
687
|
+
for (auto & smpl : chain->samplers) {
|
|
688
|
+
llama_sampler_free(smpl.ptr);
|
|
496
689
|
}
|
|
497
690
|
|
|
498
691
|
delete chain;
|
|
499
692
|
}
|
|
500
693
|
|
|
694
|
+
static bool llama_sampler_chain_backend_init(
|
|
695
|
+
struct llama_sampler * smpl,
|
|
696
|
+
ggml_backend_buffer_type_t buft) {
|
|
697
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
698
|
+
|
|
699
|
+
GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
|
|
700
|
+
|
|
701
|
+
chain->is_init = true;
|
|
702
|
+
|
|
703
|
+
bool res = true;
|
|
704
|
+
|
|
705
|
+
for (auto & smpl : chain->samplers) {
|
|
706
|
+
bool res_cur = true;
|
|
707
|
+
|
|
708
|
+
// to be able to run a sampler on the backend, it has to:
|
|
709
|
+
// - have the .backend_init() API implemented
|
|
710
|
+
// - return true during .backend_init()
|
|
711
|
+
if (smpl.ptr->iface->backend_init) {
|
|
712
|
+
if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
|
|
713
|
+
res_cur = false;
|
|
714
|
+
}
|
|
715
|
+
} else {
|
|
716
|
+
res_cur = false;
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
smpl.is_backend = res_cur;
|
|
720
|
+
|
|
721
|
+
res = res && res_cur;
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
return res;
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
static void llama_sampler_chain_backend_accept(
|
|
728
|
+
struct llama_sampler * smpl,
|
|
729
|
+
ggml_context * ctx,
|
|
730
|
+
ggml_cgraph * gf,
|
|
731
|
+
struct ggml_tensor * selected_token) {
|
|
732
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
733
|
+
|
|
734
|
+
for (auto & smpl : chain->samplers) {
|
|
735
|
+
if (!smpl.is_backend) {
|
|
736
|
+
break;
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
if (smpl.ptr->iface->backend_accept) {
|
|
740
|
+
smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
|
|
741
|
+
}
|
|
742
|
+
}
|
|
743
|
+
}
|
|
744
|
+
|
|
745
|
+
static void llama_sampler_chain_backend_apply(
|
|
746
|
+
struct llama_sampler * smpl,
|
|
747
|
+
struct ggml_context * ctx,
|
|
748
|
+
struct ggml_cgraph * gf,
|
|
749
|
+
struct llama_sampler_data * data) {
|
|
750
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
751
|
+
|
|
752
|
+
GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
|
|
753
|
+
|
|
754
|
+
for (auto & smpl : chain->samplers) {
|
|
755
|
+
if (!smpl.is_backend) {
|
|
756
|
+
break;
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
if (smpl.ptr->iface->backend_apply) {
|
|
760
|
+
smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
|
|
761
|
+
}
|
|
762
|
+
}
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
|
|
766
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
767
|
+
|
|
768
|
+
for (auto & smpl : chain->samplers) {
|
|
769
|
+
if (!smpl.is_backend) {
|
|
770
|
+
break;
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
if (smpl.ptr->iface->backend_set_input) {
|
|
774
|
+
smpl.ptr->iface->backend_set_input(smpl.ptr);
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
|
|
501
779
|
static struct llama_sampler_i llama_sampler_chain_i = {
|
|
502
|
-
/* .name
|
|
503
|
-
/* .accept
|
|
504
|
-
/* .apply
|
|
505
|
-
/* .reset
|
|
506
|
-
/* .clone
|
|
507
|
-
/* .free
|
|
780
|
+
/* .name = */ llama_sampler_chain_name,
|
|
781
|
+
/* .accept = */ llama_sampler_chain_accept,
|
|
782
|
+
/* .apply = */ llama_sampler_chain_apply,
|
|
783
|
+
/* .reset = */ llama_sampler_chain_reset,
|
|
784
|
+
/* .clone = */ llama_sampler_chain_clone,
|
|
785
|
+
/* .free = */ llama_sampler_chain_free,
|
|
786
|
+
/* .backend_init = */ llama_sampler_chain_backend_init,
|
|
787
|
+
/* .backend_accept = */ llama_sampler_chain_backend_accept,
|
|
788
|
+
/* .backend_apply = */ llama_sampler_chain_backend_apply,
|
|
789
|
+
/* .backend_set_input = */ llama_sampler_chain_backend_set_input,
|
|
508
790
|
};
|
|
509
791
|
|
|
510
792
|
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
|
@@ -512,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
|
|
512
794
|
/* .iface = */ &llama_sampler_chain_i,
|
|
513
795
|
/* .ctx = */ new llama_sampler_chain {
|
|
514
796
|
/* .params = */ params,
|
|
797
|
+
/* .is_init = */ false,
|
|
515
798
|
/* .samplers = */ {},
|
|
799
|
+
/* .cur = */ {},
|
|
516
800
|
/* .t_sample_us = */ 0,
|
|
517
801
|
/* .n_sample = */ 0,
|
|
518
802
|
}
|
|
519
803
|
);
|
|
520
804
|
}
|
|
521
805
|
|
|
806
|
+
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
|
807
|
+
const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
|
|
808
|
+
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
|
|
809
|
+
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
|
|
810
|
+
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
|
811
|
+
|
|
812
|
+
// If a backend sampler has already sampled a token, return it.
|
|
813
|
+
if (sampled_token != LLAMA_TOKEN_NULL) {
|
|
814
|
+
LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
|
|
815
|
+
return sampled_token;
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
const llama_model * model = llama_get_model(ctx);
|
|
819
|
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
820
|
+
|
|
821
|
+
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
822
|
+
|
|
823
|
+
// use pre-allocated buffer from chain if available, otherwise allocate locally
|
|
824
|
+
std::vector<llama_token_data> * cur_ptr;
|
|
825
|
+
std::vector<llama_token_data> cur_local;
|
|
826
|
+
|
|
827
|
+
if (smpl->iface == &llama_sampler_chain_i) {
|
|
828
|
+
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
|
829
|
+
cur_ptr = &chain->cur;
|
|
830
|
+
} else {
|
|
831
|
+
cur_ptr = &cur_local;
|
|
832
|
+
}
|
|
833
|
+
|
|
834
|
+
auto & cur = *cur_ptr;
|
|
835
|
+
|
|
836
|
+
if (sampled_probs) {
|
|
837
|
+
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
|
|
838
|
+
cur.resize(sampled_probs_count);
|
|
839
|
+
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
|
840
|
+
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
|
|
841
|
+
}
|
|
842
|
+
} else if (sampled_logits) {
|
|
843
|
+
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
|
|
844
|
+
cur.resize(sampled_logits_count);
|
|
845
|
+
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
|
846
|
+
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
|
|
847
|
+
}
|
|
848
|
+
} else {
|
|
849
|
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
|
850
|
+
GGML_ASSERT(logits != nullptr);
|
|
851
|
+
cur.resize(n_vocab);
|
|
852
|
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
853
|
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
|
854
|
+
}
|
|
855
|
+
}
|
|
856
|
+
|
|
857
|
+
llama_token_data_array cur_p = {
|
|
858
|
+
/* .data = */ cur.data(),
|
|
859
|
+
/* .size = */ cur.size(),
|
|
860
|
+
/* .selected = */ -1,
|
|
861
|
+
/* .sorted = */ false,
|
|
862
|
+
};
|
|
863
|
+
|
|
864
|
+
llama_sampler_apply(smpl, &cur_p);
|
|
865
|
+
|
|
866
|
+
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
|
|
867
|
+
|
|
868
|
+
auto token = cur_p.data[cur_p.selected].id;
|
|
869
|
+
|
|
870
|
+
llama_sampler_accept(smpl, token);
|
|
871
|
+
|
|
872
|
+
return token;
|
|
873
|
+
}
|
|
874
|
+
|
|
875
|
+
|
|
522
876
|
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
|
523
877
|
auto * p = (llama_sampler_chain *) chain->ctx;
|
|
524
|
-
p->samplers.push_back(
|
|
878
|
+
p->samplers.push_back({
|
|
879
|
+
/* .is_backend = */ false,
|
|
880
|
+
/* .ptr = */ smpl,
|
|
881
|
+
});
|
|
525
882
|
}
|
|
526
883
|
|
|
527
|
-
struct llama_sampler * llama_sampler_chain_get(
|
|
884
|
+
struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
|
|
885
|
+
if (chain == nullptr) {
|
|
886
|
+
return nullptr;
|
|
887
|
+
}
|
|
888
|
+
|
|
889
|
+
if (chain->iface != &llama_sampler_chain_i) {
|
|
890
|
+
return nullptr;
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
if (i == -1) {
|
|
894
|
+
return chain;
|
|
895
|
+
}
|
|
896
|
+
|
|
528
897
|
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
|
529
898
|
|
|
530
899
|
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
|
531
900
|
return nullptr;
|
|
532
901
|
}
|
|
533
902
|
|
|
534
|
-
return p->samplers[i];
|
|
903
|
+
return p->samplers[i].ptr;
|
|
535
904
|
}
|
|
536
905
|
|
|
537
906
|
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
|
@@ -541,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
|
|
|
541
910
|
return nullptr;
|
|
542
911
|
}
|
|
543
912
|
|
|
544
|
-
auto * result = p->samplers[i];
|
|
913
|
+
auto * result = p->samplers[i].ptr;
|
|
545
914
|
p->samplers.erase(p->samplers.begin() + i);
|
|
546
915
|
|
|
547
916
|
return result;
|
|
@@ -559,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
|
|
559
928
|
|
|
560
929
|
// greedy
|
|
561
930
|
|
|
562
|
-
|
|
563
|
-
|
|
931
|
+
struct llama_sampler_greedy : public llama_sampler_backend {
|
|
932
|
+
};
|
|
933
|
+
|
|
934
|
+
static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
|
|
935
|
+
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
|
|
936
|
+
return sctx->get_name();
|
|
937
|
+
}
|
|
938
|
+
|
|
939
|
+
static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
|
|
940
|
+
auto * ctx = (llama_sampler_greedy *) smpl->ctx;
|
|
941
|
+
GGML_UNUSED(ctx);
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
|
|
945
|
+
const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
|
|
946
|
+
auto * result = llama_sampler_init_greedy();
|
|
947
|
+
|
|
948
|
+
// copy the state
|
|
949
|
+
{
|
|
950
|
+
auto * result_ctx = (llama_sampler_greedy *) result->ctx;
|
|
951
|
+
|
|
952
|
+
GGML_UNUSED(ctx);
|
|
953
|
+
GGML_UNUSED(result_ctx);
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
return result;
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
|
|
960
|
+
delete (llama_sampler_greedy *) smpl->ctx;
|
|
564
961
|
}
|
|
565
962
|
|
|
566
963
|
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
|
@@ -572,33 +969,68 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
|
|
572
969
|
}
|
|
573
970
|
}
|
|
574
971
|
|
|
972
|
+
static bool llama_sampler_greedy_backend_init(
|
|
973
|
+
struct llama_sampler * smpl,
|
|
974
|
+
ggml_backend_buffer_type_t buft) {
|
|
975
|
+
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
|
|
976
|
+
|
|
977
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
978
|
+
|
|
979
|
+
sctx->init(res);
|
|
980
|
+
|
|
981
|
+
return res;
|
|
982
|
+
}
|
|
983
|
+
|
|
984
|
+
static void llama_sampler_greedy_backend_apply(
|
|
985
|
+
struct llama_sampler * smpl,
|
|
986
|
+
struct ggml_context * ctx,
|
|
987
|
+
struct ggml_cgraph * gf,
|
|
988
|
+
struct llama_sampler_data * data) {
|
|
989
|
+
GGML_UNUSED(gf);
|
|
990
|
+
GGML_UNUSED(smpl);
|
|
991
|
+
|
|
992
|
+
struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
|
|
993
|
+
ggml_set_name(curl, "greedy_argmax");
|
|
994
|
+
|
|
995
|
+
data->sampled = curl;
|
|
996
|
+
}
|
|
997
|
+
|
|
575
998
|
static struct llama_sampler_i llama_sampler_greedy_i = {
|
|
576
|
-
/* .name
|
|
577
|
-
/* .accept
|
|
578
|
-
/* .apply
|
|
579
|
-
/* .reset
|
|
580
|
-
/* .clone
|
|
581
|
-
/* .free
|
|
999
|
+
/* .name = */ llama_sampler_greedy_name,
|
|
1000
|
+
/* .accept = */ nullptr,
|
|
1001
|
+
/* .apply = */ llama_sampler_greedy_apply,
|
|
1002
|
+
/* .reset = */ llama_sampler_greedy_reset,
|
|
1003
|
+
/* .clone = */ llama_sampler_greedy_clone,
|
|
1004
|
+
/* .free = */ llama_sampler_greedy_free,
|
|
1005
|
+
/* .backend_init = */ llama_sampler_greedy_backend_init,
|
|
1006
|
+
/* .backend_accept = */ nullptr,
|
|
1007
|
+
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
|
|
1008
|
+
/* .backend_set_input = */ nullptr,
|
|
582
1009
|
};
|
|
583
1010
|
|
|
584
1011
|
struct llama_sampler * llama_sampler_init_greedy() {
|
|
585
1012
|
return llama_sampler_init(
|
|
586
1013
|
/* .iface = */ &llama_sampler_greedy_i,
|
|
587
|
-
/* .ctx = */
|
|
1014
|
+
/* .ctx = */ new llama_sampler_greedy {
|
|
1015
|
+
("greedy"),
|
|
1016
|
+
}
|
|
588
1017
|
);
|
|
589
1018
|
}
|
|
590
1019
|
|
|
591
1020
|
// dist
|
|
592
1021
|
|
|
593
|
-
struct llama_sampler_dist {
|
|
1022
|
+
struct llama_sampler_dist : public llama_sampler_backend {
|
|
594
1023
|
const uint32_t seed;
|
|
595
1024
|
uint32_t seed_cur;
|
|
596
1025
|
|
|
597
1026
|
std::mt19937 rng;
|
|
1027
|
+
|
|
1028
|
+
ggml_tensor * inp_uniform;
|
|
598
1029
|
};
|
|
599
1030
|
|
|
600
|
-
static const char * llama_sampler_dist_name(const struct llama_sampler *
|
|
601
|
-
|
|
1031
|
+
static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
|
|
1032
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1033
|
+
return sctx->get_name();
|
|
602
1034
|
}
|
|
603
1035
|
|
|
604
1036
|
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -673,6 +1105,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|
|
673
1105
|
#endif
|
|
674
1106
|
}
|
|
675
1107
|
|
|
1108
|
+
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
1109
|
+
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
1110
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
1111
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
1112
|
+
}
|
|
1113
|
+
|
|
676
1114
|
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
|
677
1115
|
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
|
678
1116
|
auto * result = llama_sampler_init_dist(ctx->seed);
|
|
@@ -687,23 +1125,106 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
|
|
|
687
1125
|
return result;
|
|
688
1126
|
}
|
|
689
1127
|
|
|
690
|
-
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
|
691
|
-
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
|
692
|
-
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
693
|
-
ctx->rng.seed(ctx->seed_cur);
|
|
694
|
-
}
|
|
695
|
-
|
|
696
1128
|
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|
697
1129
|
delete (llama_sampler_dist *) smpl->ctx;
|
|
698
1130
|
}
|
|
699
1131
|
|
|
1132
|
+
static bool llama_sampler_dist_backend_init(
|
|
1133
|
+
struct llama_sampler * smpl,
|
|
1134
|
+
ggml_backend_buffer_type_t buft) {
|
|
1135
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1136
|
+
|
|
1137
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1138
|
+
|
|
1139
|
+
sctx->init(res);
|
|
1140
|
+
|
|
1141
|
+
return res;
|
|
1142
|
+
}
|
|
1143
|
+
|
|
1144
|
+
static void llama_sampler_dist_backend_apply(
|
|
1145
|
+
struct llama_sampler * smpl,
|
|
1146
|
+
struct ggml_context * ctx,
|
|
1147
|
+
struct ggml_cgraph * gf,
|
|
1148
|
+
struct llama_sampler_data * data) {
|
|
1149
|
+
GGML_UNUSED(gf);
|
|
1150
|
+
|
|
1151
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1152
|
+
|
|
1153
|
+
sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
|
1154
|
+
ggml_set_name (sctx->inp_uniform, "uniform");
|
|
1155
|
+
ggml_set_input(sctx->inp_uniform);
|
|
1156
|
+
|
|
1157
|
+
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
|
1158
|
+
ggml_set_name(probs, "dist_probs");
|
|
1159
|
+
|
|
1160
|
+
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
|
1161
|
+
ggml_set_name(cumsum, "dist_cumsum");
|
|
1162
|
+
|
|
1163
|
+
// The uniform tensor has a random value and we subtract this tensor with
|
|
1164
|
+
// the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
|
|
1165
|
+
// Recall that each entry in cumsum is the cumulative probability up to that
|
|
1166
|
+
// index so values stay negative while the cumulative total is below the
|
|
1167
|
+
// random value, and become zero/positive once the threshold is crossed.
|
|
1168
|
+
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
|
|
1169
|
+
ggml_set_name(diff, "dist_cumsum");
|
|
1170
|
+
|
|
1171
|
+
// The ggml_step function produces a tensor where entries are 1 if the
|
|
1172
|
+
// corresponding entry in diff is > 0, and 0 otherwise. So all values up to
|
|
1173
|
+
// the index where the cumulative probability exceeds the random value are 0,
|
|
1174
|
+
// and all entries after that are 1.
|
|
1175
|
+
struct ggml_tensor * mask = ggml_step(ctx, diff);
|
|
1176
|
+
ggml_set_name(mask, "dist_mask");
|
|
1177
|
+
|
|
1178
|
+
// Taking the sum of the mask gives us the sum of elements after the threshold
|
|
1179
|
+
// we are interested in.
|
|
1180
|
+
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
|
1181
|
+
ggml_set_name(idxf, "dist_index_f32");
|
|
1182
|
+
|
|
1183
|
+
// Use ggml_scale_bias to scale the index value by -1 and then add the size
|
|
1184
|
+
// of the mask to that value so we get the correct index ((-1 * idxf) + n).
|
|
1185
|
+
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
|
1186
|
+
ggml_set_name(idx, "dist_index_i32");
|
|
1187
|
+
|
|
1188
|
+
// Map back to original vocab ids if a candidates tensor is available.
|
|
1189
|
+
struct ggml_tensor * sampled_token = idx;
|
|
1190
|
+
if (data->candidates != nullptr) {
|
|
1191
|
+
struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
|
|
1192
|
+
|
|
1193
|
+
sampled_token = ggml_get_rows(ctx, candidates, idx);
|
|
1194
|
+
ggml_set_name(sampled_token, "dist_sampled_token");
|
|
1195
|
+
}
|
|
1196
|
+
|
|
1197
|
+
data->sampled = sampled_token;
|
|
1198
|
+
data->probs = probs;
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1201
|
+
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
|
1202
|
+
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
|
1203
|
+
|
|
1204
|
+
GGML_ASSERT(sctx->inp_uniform != nullptr);
|
|
1205
|
+
|
|
1206
|
+
// We sample in double precision and cast to float to match rnd numbers of
|
|
1207
|
+
// llama_dampler_dist which uses double precision (sampling from
|
|
1208
|
+
// std::uniform_real_distribution<double> and
|
|
1209
|
+
// std::uniform_real_distribution<float> with same rng will produce
|
|
1210
|
+
// different sequences).
|
|
1211
|
+
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
|
1212
|
+
const float rnd = dist(sctx->rng);
|
|
1213
|
+
|
|
1214
|
+
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
|
1215
|
+
}
|
|
1216
|
+
|
|
700
1217
|
static struct llama_sampler_i llama_sampler_dist_i = {
|
|
701
|
-
/* .name
|
|
702
|
-
/* .accept
|
|
703
|
-
/* .apply
|
|
704
|
-
/* .reset
|
|
705
|
-
/* .clone
|
|
706
|
-
/* .free
|
|
1218
|
+
/* .name = */ llama_sampler_dist_name,
|
|
1219
|
+
/* .accept = */ nullptr,
|
|
1220
|
+
/* .apply = */ llama_sampler_dist_apply,
|
|
1221
|
+
/* .reset = */ llama_sampler_dist_reset,
|
|
1222
|
+
/* .clone = */ llama_sampler_dist_clone,
|
|
1223
|
+
/* .free = */ llama_sampler_dist_free,
|
|
1224
|
+
/* .backend_init = */ llama_sampler_dist_backend_init,
|
|
1225
|
+
/* .backend_accept = */ nullptr,
|
|
1226
|
+
/* .backend_apply = */ llama_sampler_dist_backend_apply,
|
|
1227
|
+
/* .backend_set_input = */ llama_sampler_dist_backend_set_input,
|
|
707
1228
|
};
|
|
708
1229
|
|
|
709
1230
|
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
@@ -711,21 +1232,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
|
|
711
1232
|
return llama_sampler_init(
|
|
712
1233
|
/* .iface = */ &llama_sampler_dist_i,
|
|
713
1234
|
/* .ctx = */ new llama_sampler_dist {
|
|
714
|
-
|
|
715
|
-
/* .
|
|
716
|
-
/* .
|
|
1235
|
+
("dist"),
|
|
1236
|
+
/* .seed = */ seed,
|
|
1237
|
+
/* .seed_cur = */ seed_cur,
|
|
1238
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
1239
|
+
/* .inp_uniform = */ nullptr,
|
|
717
1240
|
}
|
|
718
1241
|
);
|
|
719
1242
|
}
|
|
720
1243
|
|
|
721
1244
|
// top-k
|
|
722
1245
|
|
|
723
|
-
struct llama_sampler_top_k {
|
|
1246
|
+
struct llama_sampler_top_k : public llama_sampler_backend {
|
|
724
1247
|
const int32_t k;
|
|
725
1248
|
};
|
|
726
1249
|
|
|
727
|
-
static const char * llama_sampler_top_k_name(const struct llama_sampler *
|
|
728
|
-
|
|
1250
|
+
static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
|
|
1251
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1252
|
+
return sctx->get_name();
|
|
729
1253
|
}
|
|
730
1254
|
|
|
731
1255
|
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -742,19 +1266,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|
|
742
1266
|
delete (llama_sampler_top_k *) smpl->ctx;
|
|
743
1267
|
}
|
|
744
1268
|
|
|
1269
|
+
static bool llama_sampler_top_k_backend_init(
|
|
1270
|
+
struct llama_sampler * smpl,
|
|
1271
|
+
ggml_backend_buffer_type_t buft) {
|
|
1272
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1273
|
+
|
|
1274
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1275
|
+
|
|
1276
|
+
sctx->init(res);
|
|
1277
|
+
|
|
1278
|
+
return res;
|
|
1279
|
+
}
|
|
1280
|
+
|
|
1281
|
+
static void llama_sampler_top_k_backend_apply(
|
|
1282
|
+
struct llama_sampler * smpl,
|
|
1283
|
+
struct ggml_context * ctx,
|
|
1284
|
+
struct ggml_cgraph * gf,
|
|
1285
|
+
struct llama_sampler_data * data) {
|
|
1286
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1287
|
+
|
|
1288
|
+
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
|
|
1289
|
+
ggml_set_name(top_k, "top_k");
|
|
1290
|
+
|
|
1291
|
+
if (data->candidates) {
|
|
1292
|
+
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
|
1293
|
+
data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
|
|
1294
|
+
data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
|
|
1295
|
+
ggml_set_name(data->candidates, "top_k_candidates");
|
|
1296
|
+
} else {
|
|
1297
|
+
data->candidates = top_k;
|
|
1298
|
+
}
|
|
1299
|
+
|
|
1300
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1301
|
+
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
|
1302
|
+
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
|
|
1303
|
+
ggml_set_name(top_k_rows, "top_k_rows");
|
|
1304
|
+
|
|
1305
|
+
GGML_UNUSED(gf);
|
|
1306
|
+
}
|
|
1307
|
+
|
|
745
1308
|
static struct llama_sampler_i llama_sampler_top_k_i = {
|
|
746
|
-
/* .name
|
|
747
|
-
/* .accept
|
|
748
|
-
/* .apply
|
|
749
|
-
/* .reset
|
|
750
|
-
/* .clone
|
|
751
|
-
/* .free
|
|
1309
|
+
/* .name = */ llama_sampler_top_k_name,
|
|
1310
|
+
/* .accept = */ nullptr,
|
|
1311
|
+
/* .apply = */ llama_sampler_top_k_apply,
|
|
1312
|
+
/* .reset = */ nullptr,
|
|
1313
|
+
/* .clone = */ llama_sampler_top_k_clone,
|
|
1314
|
+
/* .free = */ llama_sampler_top_k_free,
|
|
1315
|
+
/* .backend_init = */ llama_sampler_top_k_backend_init,
|
|
1316
|
+
/* .backend_accept = */ nullptr,
|
|
1317
|
+
/* .backend_apply = */ llama_sampler_top_k_backend_apply,
|
|
1318
|
+
/* .backend_set_input = */ nullptr,
|
|
752
1319
|
};
|
|
753
1320
|
|
|
754
1321
|
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
|
1322
|
+
const bool is_empty = (k <= 0);
|
|
1323
|
+
|
|
1324
|
+
if (is_empty) {
|
|
1325
|
+
return llama_sampler_init_empty("?top-k");
|
|
1326
|
+
}
|
|
1327
|
+
|
|
755
1328
|
return llama_sampler_init(
|
|
756
1329
|
/* .iface = */ &llama_sampler_top_k_i,
|
|
757
1330
|
/* .ctx = */ new llama_sampler_top_k {
|
|
1331
|
+
("top-k"),
|
|
758
1332
|
/* .k = */ k,
|
|
759
1333
|
}
|
|
760
1334
|
);
|
|
@@ -762,15 +1336,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
|
|
762
1336
|
|
|
763
1337
|
// top-p
|
|
764
1338
|
|
|
765
|
-
struct llama_sampler_top_p {
|
|
1339
|
+
struct llama_sampler_top_p : public llama_sampler_backend {
|
|
766
1340
|
const float p;
|
|
767
1341
|
const size_t min_keep;
|
|
768
1342
|
|
|
769
1343
|
std::vector<llama_token_data> buf_sort;
|
|
770
1344
|
};
|
|
771
1345
|
|
|
772
|
-
static const char * llama_sampler_top_p_name(const struct llama_sampler *
|
|
773
|
-
|
|
1346
|
+
static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
|
|
1347
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1348
|
+
return sctx->get_name();
|
|
774
1349
|
}
|
|
775
1350
|
|
|
776
1351
|
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -837,19 +1412,115 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
|
|
837
1412
|
delete (llama_sampler_top_p *) smpl->ctx;
|
|
838
1413
|
}
|
|
839
1414
|
|
|
1415
|
+
static bool llama_sampler_top_p_backend_init(
|
|
1416
|
+
struct llama_sampler * smpl,
|
|
1417
|
+
ggml_backend_buffer_type_t buft) {
|
|
1418
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1419
|
+
|
|
1420
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1421
|
+
|
|
1422
|
+
sctx->init(res);
|
|
1423
|
+
|
|
1424
|
+
return res;
|
|
1425
|
+
}
|
|
1426
|
+
|
|
1427
|
+
static void llama_sampler_top_p_backend_apply(
|
|
1428
|
+
struct llama_sampler * smpl,
|
|
1429
|
+
struct ggml_context * ctx,
|
|
1430
|
+
struct ggml_cgraph * gf,
|
|
1431
|
+
struct llama_sampler_data * data) {
|
|
1432
|
+
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
|
1433
|
+
|
|
1434
|
+
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
|
|
1435
|
+
GGML_ASSERT(ggml_nrows(a) == 1);
|
|
1436
|
+
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
|
|
1437
|
+
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
|
|
1438
|
+
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
|
|
1439
|
+
};
|
|
1440
|
+
|
|
1441
|
+
// Get the sorted logits in descending order.
|
|
1442
|
+
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
|
|
1443
|
+
ggml_set_name(sorted_idx, "top_p_sorted_idx");
|
|
1444
|
+
|
|
1445
|
+
// Do the sorting via reshape + get_rows
|
|
1446
|
+
struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
|
|
1447
|
+
ggml_set_name(sorted_logits, "top_p_sorted_logits");
|
|
1448
|
+
|
|
1449
|
+
struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
|
|
1450
|
+
ggml_set_name(softmax, "top_p_softmax");
|
|
1451
|
+
|
|
1452
|
+
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
|
|
1453
|
+
if (data->candidates) {
|
|
1454
|
+
data->candidates = ggml_sort(data->candidates, sorted_idx);
|
|
1455
|
+
} else {
|
|
1456
|
+
data->candidates = sorted_idx;
|
|
1457
|
+
}
|
|
1458
|
+
ggml_set_name(data->candidates, "top_p_candidates");
|
|
1459
|
+
|
|
1460
|
+
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
|
|
1461
|
+
struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
|
|
1462
|
+
ggml_set_name(cdf, "top_p_cdf");
|
|
1463
|
+
|
|
1464
|
+
// Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
|
|
1465
|
+
struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
|
|
1466
|
+
ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
|
|
1467
|
+
|
|
1468
|
+
struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
|
|
1469
|
+
ggml_set_name(mask, "top_p_mask");
|
|
1470
|
+
|
|
1471
|
+
// Taking the sum of the mask gives us the sum of elements after the threshold
|
|
1472
|
+
// we are interested in.
|
|
1473
|
+
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
|
1474
|
+
ggml_set_name(idxf, "top_p_index_f32");
|
|
1475
|
+
|
|
1476
|
+
// prevent out-of-bounds access
|
|
1477
|
+
idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
|
|
1478
|
+
|
|
1479
|
+
// construct ones tensor to set the value in the mask
|
|
1480
|
+
struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
|
|
1481
|
+
ggml_set_name(ones, "top_p_ones");
|
|
1482
|
+
|
|
1483
|
+
// Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
|
|
1484
|
+
struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
|
|
1485
|
+
|
|
1486
|
+
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
|
|
1487
|
+
mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
|
|
1488
|
+
|
|
1489
|
+
// Apply -INFINITY bias for masked-out tokens
|
|
1490
|
+
// log(1) = 0 (keep), log(0) = -INF (discard)
|
|
1491
|
+
struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
|
|
1492
|
+
ggml_set_name(top_p_bias, "top_p_bias");
|
|
1493
|
+
|
|
1494
|
+
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
|
|
1495
|
+
ggml_set_name(data->logits, "top_p_logits");
|
|
1496
|
+
|
|
1497
|
+
GGML_UNUSED(gf);
|
|
1498
|
+
}
|
|
1499
|
+
|
|
840
1500
|
static struct llama_sampler_i llama_sampler_top_p_i = {
|
|
841
|
-
/* .name
|
|
842
|
-
/* .accept
|
|
843
|
-
/* .apply
|
|
844
|
-
/* .reset
|
|
845
|
-
/* .clone
|
|
846
|
-
/* .free
|
|
1501
|
+
/* .name = */ llama_sampler_top_p_name,
|
|
1502
|
+
/* .accept = */ nullptr,
|
|
1503
|
+
/* .apply = */ llama_sampler_top_p_apply,
|
|
1504
|
+
/* .reset = */ nullptr,
|
|
1505
|
+
/* .clone = */ llama_sampler_top_p_clone,
|
|
1506
|
+
/* .free = */ llama_sampler_top_p_free,
|
|
1507
|
+
/* .backend_init = */ llama_sampler_top_p_backend_init,
|
|
1508
|
+
/* .backend_accept = */ nullptr,
|
|
1509
|
+
/* .backend_apply = */ llama_sampler_top_p_backend_apply,
|
|
1510
|
+
/* .backend_set_input = */ nullptr,
|
|
847
1511
|
};
|
|
848
1512
|
|
|
849
1513
|
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
|
1514
|
+
const bool is_empty = p >= 1.0f;
|
|
1515
|
+
|
|
1516
|
+
if (is_empty) {
|
|
1517
|
+
return llama_sampler_init_empty("?top-p");
|
|
1518
|
+
}
|
|
1519
|
+
|
|
850
1520
|
return llama_sampler_init(
|
|
851
1521
|
/* .iface = */ &llama_sampler_top_p_i,
|
|
852
1522
|
/* .ctx = */ new llama_sampler_top_p {
|
|
1523
|
+
("top-p"),
|
|
853
1524
|
/* .p = */ p,
|
|
854
1525
|
/* .min_keep = */ min_keep,
|
|
855
1526
|
/* .buf_sort = */ {},
|
|
@@ -859,13 +1530,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
|
|
859
1530
|
|
|
860
1531
|
// min-p
|
|
861
1532
|
|
|
862
|
-
struct llama_sampler_min_p {
|
|
1533
|
+
struct llama_sampler_min_p : public llama_sampler_backend {
|
|
863
1534
|
const float p;
|
|
864
1535
|
const size_t min_keep;
|
|
865
1536
|
};
|
|
866
1537
|
|
|
867
|
-
static const char * llama_sampler_min_p_name(const struct llama_sampler *
|
|
868
|
-
|
|
1538
|
+
static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
|
|
1539
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1540
|
+
return sctx->get_name();
|
|
869
1541
|
}
|
|
870
1542
|
|
|
871
1543
|
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -931,19 +1603,81 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
|
|
931
1603
|
delete (llama_sampler_min_p *) smpl->ctx;
|
|
932
1604
|
}
|
|
933
1605
|
|
|
1606
|
+
static bool llama_sampler_min_p_backend_init(
|
|
1607
|
+
struct llama_sampler * smpl,
|
|
1608
|
+
ggml_backend_buffer_type_t buft) {
|
|
1609
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1610
|
+
|
|
1611
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1612
|
+
|
|
1613
|
+
sctx->init(res);
|
|
1614
|
+
|
|
1615
|
+
return res;
|
|
1616
|
+
}
|
|
1617
|
+
|
|
1618
|
+
static void llama_sampler_min_p_backend_apply(
|
|
1619
|
+
struct llama_sampler * smpl,
|
|
1620
|
+
struct ggml_context * ctx,
|
|
1621
|
+
struct ggml_cgraph * gf,
|
|
1622
|
+
struct llama_sampler_data * data) {
|
|
1623
|
+
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
|
1624
|
+
|
|
1625
|
+
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
|
1626
|
+
ggml_set_name(max_idx, "max_idx");
|
|
1627
|
+
|
|
1628
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1629
|
+
ggml_set_name(logits_rows, "logits_rows");
|
|
1630
|
+
|
|
1631
|
+
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
|
|
1632
|
+
ggml_set_name(max_logit, "max_logit");
|
|
1633
|
+
|
|
1634
|
+
// Calculate the threshold value.
|
|
1635
|
+
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
|
|
1636
|
+
ggml_set_name(threshold, "min_p_threshold");
|
|
1637
|
+
|
|
1638
|
+
// Subtract the threshold from logits.
|
|
1639
|
+
struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
|
|
1640
|
+
|
|
1641
|
+
// Create a mask where logits below the threshold are 0 (discard),
|
|
1642
|
+
// and others are 1 (keep).
|
|
1643
|
+
struct ggml_tensor * mask = ggml_step(ctx, sub);
|
|
1644
|
+
ggml_set_name(mask, "min_p_mask");
|
|
1645
|
+
|
|
1646
|
+
// Apply -INFINITY bias for masked-out tokens
|
|
1647
|
+
// log(1) = 0 (keep), log(0) = -INF (discard)
|
|
1648
|
+
struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
|
|
1649
|
+
ggml_set_name(min_p_bias, "min_p_bias");
|
|
1650
|
+
|
|
1651
|
+
data->logits = ggml_add(ctx, data->logits, min_p_bias);
|
|
1652
|
+
ggml_set_name(data->logits, "min_p_logits");
|
|
1653
|
+
|
|
1654
|
+
GGML_UNUSED(gf);
|
|
1655
|
+
}
|
|
1656
|
+
|
|
934
1657
|
static struct llama_sampler_i llama_sampler_min_p_i = {
|
|
935
|
-
/* .name
|
|
936
|
-
/* .accept
|
|
937
|
-
/* .apply
|
|
938
|
-
/* .reset
|
|
939
|
-
/* .clone
|
|
940
|
-
/* .free
|
|
1658
|
+
/* .name = */ llama_sampler_min_p_name,
|
|
1659
|
+
/* .accept = */ nullptr,
|
|
1660
|
+
/* .apply = */ llama_sampler_min_p_apply,
|
|
1661
|
+
/* .reset = */ nullptr,
|
|
1662
|
+
/* .clone = */ llama_sampler_min_p_clone,
|
|
1663
|
+
/* .free = */ llama_sampler_min_p_free,
|
|
1664
|
+
/* .backend_init = */ llama_sampler_min_p_backend_init,
|
|
1665
|
+
/* .backend_accept = */ nullptr,
|
|
1666
|
+
/* .backend_apply = */ llama_sampler_min_p_backend_apply,
|
|
1667
|
+
/* .backend_set_input = */ nullptr,
|
|
941
1668
|
};
|
|
942
1669
|
|
|
943
1670
|
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
|
1671
|
+
const bool is_empty = (p <= 0.0f);
|
|
1672
|
+
|
|
1673
|
+
if (is_empty) {
|
|
1674
|
+
return llama_sampler_init_empty("?min-p");
|
|
1675
|
+
}
|
|
1676
|
+
|
|
944
1677
|
return llama_sampler_init(
|
|
945
1678
|
/* .iface = */ &llama_sampler_min_p_i,
|
|
946
1679
|
/* .ctx = */ new llama_sampler_min_p {
|
|
1680
|
+
("min-p"),
|
|
947
1681
|
/* .p = */ p,
|
|
948
1682
|
/* .min_keep = */ min_keep,
|
|
949
1683
|
}
|
|
@@ -1031,15 +1765,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
|
|
1031
1765
|
}
|
|
1032
1766
|
|
|
1033
1767
|
static struct llama_sampler_i llama_sampler_typical_i = {
|
|
1034
|
-
/* .name
|
|
1035
|
-
/* .accept
|
|
1036
|
-
/* .apply
|
|
1037
|
-
/* .reset
|
|
1038
|
-
/* .clone
|
|
1039
|
-
/* .free
|
|
1768
|
+
/* .name = */ llama_sampler_typical_name,
|
|
1769
|
+
/* .accept = */ nullptr,
|
|
1770
|
+
/* .apply = */ llama_sampler_typical_apply,
|
|
1771
|
+
/* .reset = */ nullptr,
|
|
1772
|
+
/* .clone = */ llama_sampler_typical_clone,
|
|
1773
|
+
/* .free = */ llama_sampler_typical_free,
|
|
1774
|
+
/* .backend_init = */ nullptr,
|
|
1775
|
+
/* .backend_accept = */ nullptr,
|
|
1776
|
+
/* .backend_apply = */ nullptr,
|
|
1777
|
+
/* .backend_set_input = */ nullptr,
|
|
1040
1778
|
};
|
|
1041
1779
|
|
|
1042
1780
|
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
|
1781
|
+
const bool is_empty = (p >= 1.0f);
|
|
1782
|
+
|
|
1783
|
+
if (is_empty) {
|
|
1784
|
+
return llama_sampler_init_empty("?typical");
|
|
1785
|
+
}
|
|
1786
|
+
|
|
1043
1787
|
return llama_sampler_init(
|
|
1044
1788
|
/* .iface = */ &llama_sampler_typical_i,
|
|
1045
1789
|
/* .ctx = */ new llama_sampler_typical {
|
|
@@ -1051,12 +1795,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
|
|
1051
1795
|
|
|
1052
1796
|
// temp
|
|
1053
1797
|
|
|
1054
|
-
struct llama_sampler_temp {
|
|
1798
|
+
struct llama_sampler_temp : public llama_sampler_backend {
|
|
1055
1799
|
const float temp;
|
|
1056
1800
|
};
|
|
1057
1801
|
|
|
1058
|
-
static const char * llama_sampler_temp_name(const struct llama_sampler *
|
|
1059
|
-
|
|
1802
|
+
static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
|
|
1803
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1804
|
+
return sctx->get_name();
|
|
1060
1805
|
}
|
|
1061
1806
|
|
|
1062
1807
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -1074,19 +1819,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
|
|
1074
1819
|
delete (llama_sampler_temp *) smpl->ctx;
|
|
1075
1820
|
}
|
|
1076
1821
|
|
|
1822
|
+
static void llama_sampler_backend_temp_sampling(
|
|
1823
|
+
struct ggml_context * ctx,
|
|
1824
|
+
struct ggml_cgraph * gf,
|
|
1825
|
+
struct llama_sampler_data * data,
|
|
1826
|
+
float temp) {
|
|
1827
|
+
if (temp <= 0.0f) {
|
|
1828
|
+
// Find the most probable token index.
|
|
1829
|
+
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
|
1830
|
+
ggml_set_name(max_idx, "temp_max_idx");
|
|
1831
|
+
|
|
1832
|
+
if (data->candidates) {
|
|
1833
|
+
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
|
1834
|
+
data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
|
|
1835
|
+
} else {
|
|
1836
|
+
data->candidates = max_idx;
|
|
1837
|
+
}
|
|
1838
|
+
|
|
1839
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1840
|
+
data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
|
|
1841
|
+
|
|
1842
|
+
return;
|
|
1843
|
+
}
|
|
1844
|
+
|
|
1845
|
+
data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
|
|
1846
|
+
|
|
1847
|
+
GGML_UNUSED(gf);
|
|
1848
|
+
}
|
|
1849
|
+
|
|
1850
|
+
static bool llama_sampler_temp_backend_init(
|
|
1851
|
+
struct llama_sampler * smpl,
|
|
1852
|
+
ggml_backend_buffer_type_t buft) {
|
|
1853
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1854
|
+
|
|
1855
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1856
|
+
|
|
1857
|
+
sctx->init(res);
|
|
1858
|
+
|
|
1859
|
+
return res;
|
|
1860
|
+
}
|
|
1861
|
+
|
|
1862
|
+
static void llama_sampler_temp_backend_apply(
|
|
1863
|
+
struct llama_sampler * smpl,
|
|
1864
|
+
struct ggml_context * ctx,
|
|
1865
|
+
struct ggml_cgraph * gf,
|
|
1866
|
+
struct llama_sampler_data * data) {
|
|
1867
|
+
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
|
1868
|
+
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
|
1869
|
+
}
|
|
1870
|
+
|
|
1077
1871
|
static struct llama_sampler_i llama_sampler_temp_i = {
|
|
1078
|
-
/* .name
|
|
1079
|
-
/* .accept
|
|
1080
|
-
/* .apply
|
|
1081
|
-
/* .reset
|
|
1082
|
-
/* .clone
|
|
1083
|
-
/* .free
|
|
1872
|
+
/* .name = */ llama_sampler_temp_name,
|
|
1873
|
+
/* .accept = */ nullptr,
|
|
1874
|
+
/* .apply = */ llama_sampler_temp_apply,
|
|
1875
|
+
/* .reset = */ nullptr,
|
|
1876
|
+
/* .clone = */ llama_sampler_temp_clone,
|
|
1877
|
+
/* .free = */ llama_sampler_temp_free,
|
|
1878
|
+
/* .backend_init = */ llama_sampler_temp_backend_init,
|
|
1879
|
+
/* .backend_accept = */ nullptr,
|
|
1880
|
+
/* .backend_apply = */ llama_sampler_temp_backend_apply,
|
|
1881
|
+
/* .backend_set_input = */ nullptr,
|
|
1084
1882
|
};
|
|
1085
1883
|
|
|
1086
1884
|
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
|
1885
|
+
const bool is_empty = temp == 1.0f;
|
|
1886
|
+
|
|
1887
|
+
if (is_empty) {
|
|
1888
|
+
return llama_sampler_init_empty("?temp");
|
|
1889
|
+
}
|
|
1890
|
+
|
|
1087
1891
|
return llama_sampler_init(
|
|
1088
1892
|
/* .iface = */ &llama_sampler_temp_i,
|
|
1089
1893
|
/* .ctx = */ new llama_sampler_temp {
|
|
1894
|
+
("temp"),
|
|
1090
1895
|
/*.temp = */ temp,
|
|
1091
1896
|
}
|
|
1092
1897
|
);
|
|
@@ -1094,14 +1899,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
|
|
|
1094
1899
|
|
|
1095
1900
|
// temp-ext
|
|
1096
1901
|
|
|
1097
|
-
struct llama_sampler_temp_ext {
|
|
1902
|
+
struct llama_sampler_temp_ext : public llama_sampler_backend {
|
|
1098
1903
|
const float temp;
|
|
1099
1904
|
const float delta;
|
|
1100
1905
|
const float exponent;
|
|
1101
1906
|
};
|
|
1102
1907
|
|
|
1103
|
-
static const char * llama_sampler_temp_ext_name(const struct llama_sampler *
|
|
1104
|
-
|
|
1908
|
+
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
|
|
1909
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
1910
|
+
return sctx->get_name();
|
|
1105
1911
|
}
|
|
1106
1912
|
|
|
1107
1913
|
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -1184,24 +1990,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
|
|
1184
1990
|
delete (llama_sampler_temp_ext *) smpl->ctx;
|
|
1185
1991
|
}
|
|
1186
1992
|
|
|
1993
|
+
static bool llama_sampler_temp_ext_backend_init(
|
|
1994
|
+
struct llama_sampler * smpl,
|
|
1995
|
+
ggml_backend_buffer_type_t buft) {
|
|
1996
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
1997
|
+
|
|
1998
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1999
|
+
|
|
2000
|
+
sctx->init(res);
|
|
2001
|
+
|
|
2002
|
+
return res;
|
|
2003
|
+
}
|
|
2004
|
+
|
|
2005
|
+
static void llama_sampler_temp_ext_backend_apply(
|
|
2006
|
+
struct llama_sampler * smpl,
|
|
2007
|
+
struct ggml_context * ctx,
|
|
2008
|
+
struct ggml_cgraph * gf,
|
|
2009
|
+
struct llama_sampler_data * data) {
|
|
2010
|
+
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
|
2011
|
+
|
|
2012
|
+
// Revert to standard temperature scaling if delta or temp are non-positive.
|
|
2013
|
+
if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
|
|
2014
|
+
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
|
2015
|
+
return;
|
|
2016
|
+
}
|
|
2017
|
+
|
|
2018
|
+
// Calculate min_temp, max_temp, and max_entropy.
|
|
2019
|
+
const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
|
|
2020
|
+
const float max_temp = sctx->temp + sctx->delta;
|
|
2021
|
+
const float max_entropy = logf(data->logits->ne[0]);
|
|
2022
|
+
|
|
2023
|
+
// Calculate the probabilities.
|
|
2024
|
+
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
|
2025
|
+
ggml_set_name(probs, "temp_ext_softmax_probs");
|
|
2026
|
+
|
|
2027
|
+
// Clamp probabilities to avoid log(0) which would give -inf
|
|
2028
|
+
struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
|
|
2029
|
+
ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
|
|
2030
|
+
|
|
2031
|
+
// Calculate the entropy, entropy = -Σ(p * log(p)).
|
|
2032
|
+
struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
|
|
2033
|
+
struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
|
|
2034
|
+
struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
|
|
2035
|
+
struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
|
|
2036
|
+
ggml_set_name(log_probs, "temp_ext_log_probs");
|
|
2037
|
+
ggml_set_name(p_log_p, "temp_ext_p_log_p");
|
|
2038
|
+
ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
|
|
2039
|
+
ggml_set_name(entropy, "temp_ext_entropy");
|
|
2040
|
+
|
|
2041
|
+
// Normalize the entropy, norm_entropy = entropy / max_entropy
|
|
2042
|
+
struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
|
|
2043
|
+
ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
|
|
2044
|
+
|
|
2045
|
+
// Calculate the dynamic temperature:
|
|
2046
|
+
// dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
|
|
2047
|
+
//
|
|
2048
|
+
// Calculate powf(normalized_entropy, exponent) as
|
|
2049
|
+
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
|
|
2050
|
+
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
|
|
2051
|
+
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
|
|
2052
|
+
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
|
|
2053
|
+
// With pow_entropy computed we can now compute dyn_temp, scaling by
|
|
2054
|
+
// (max_temp - min_temp) and then adding min_temp.
|
|
2055
|
+
struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
|
|
2056
|
+
ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
|
|
2057
|
+
ggml_set_name(scaled_log, "temp_ext_scaled_log");
|
|
2058
|
+
ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
|
|
2059
|
+
ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
|
|
2060
|
+
|
|
2061
|
+
// Scale the logits by the dynamic temperature
|
|
2062
|
+
struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
|
|
2063
|
+
ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
|
|
2064
|
+
|
|
2065
|
+
data->logits = scaled_logits;
|
|
2066
|
+
}
|
|
2067
|
+
|
|
1187
2068
|
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
|
1188
|
-
/* .name
|
|
1189
|
-
/* .accept
|
|
1190
|
-
/* .apply
|
|
1191
|
-
/* .reset
|
|
1192
|
-
/* .clone
|
|
1193
|
-
/* .free
|
|
2069
|
+
/* .name = */ llama_sampler_temp_ext_name,
|
|
2070
|
+
/* .accept = */ nullptr,
|
|
2071
|
+
/* .apply = */ llama_sampler_temp_ext_apply,
|
|
2072
|
+
/* .reset = */ nullptr,
|
|
2073
|
+
/* .clone = */ llama_sampler_temp_ext_clone,
|
|
2074
|
+
/* .free = */ llama_sampler_temp_ext_free,
|
|
2075
|
+
/* .backend_init = */ llama_sampler_temp_ext_backend_init,
|
|
2076
|
+
/* .backend_accept = */ nullptr,
|
|
2077
|
+
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
|
|
2078
|
+
/* .backend_set_input = */ nullptr,
|
|
1194
2079
|
};
|
|
1195
2080
|
|
|
1196
2081
|
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
|
1197
|
-
|
|
2082
|
+
const bool is_empty = temp == 1.0f && delta <= 0.0f;
|
|
2083
|
+
|
|
2084
|
+
if (is_empty) {
|
|
2085
|
+
return llama_sampler_init_empty("?temp-ext");
|
|
2086
|
+
}
|
|
2087
|
+
|
|
2088
|
+
auto * res = llama_sampler_init(
|
|
1198
2089
|
/* .iface = */ &llama_sampler_temp_ext_i,
|
|
1199
2090
|
/* .ctx = */ new llama_sampler_temp_ext {
|
|
2091
|
+
("temp-ext"),
|
|
1200
2092
|
/* .temp = */ temp,
|
|
1201
2093
|
/* .delta = */ delta,
|
|
1202
2094
|
/* .exponent = */ exponent,
|
|
1203
2095
|
}
|
|
1204
2096
|
);
|
|
2097
|
+
|
|
2098
|
+
return res;
|
|
1205
2099
|
}
|
|
1206
2100
|
|
|
1207
2101
|
// xtc
|
|
@@ -1214,7 +2108,7 @@ struct llama_sampler_xtc {
|
|
|
1214
2108
|
const uint32_t seed;
|
|
1215
2109
|
uint32_t seed_cur;
|
|
1216
2110
|
|
|
1217
|
-
std::mt19937
|
|
2111
|
+
std::mt19937 rng;
|
|
1218
2112
|
};
|
|
1219
2113
|
|
|
1220
2114
|
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
|
@@ -1279,16 +2173,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
|
|
1279
2173
|
}
|
|
1280
2174
|
|
|
1281
2175
|
static struct llama_sampler_i llama_sampler_xtc_i = {
|
|
1282
|
-
/* .name
|
|
1283
|
-
/* .accept
|
|
1284
|
-
/* .apply
|
|
1285
|
-
/* .reset
|
|
1286
|
-
/* .clone
|
|
1287
|
-
/* .free
|
|
2176
|
+
/* .name = */ llama_sampler_xtc_name,
|
|
2177
|
+
/* .accept = */ nullptr,
|
|
2178
|
+
/* .apply = */ llama_sample_xtc_apply,
|
|
2179
|
+
/* .reset = */ llama_sampler_xtc_reset,
|
|
2180
|
+
/* .clone = */ llama_sampler_xtc_clone,
|
|
2181
|
+
/* .free = */ llama_sampler_xtc_free,
|
|
2182
|
+
/* .backend_init = */ nullptr,
|
|
2183
|
+
/* .backend_accept = */ nullptr,
|
|
2184
|
+
/* .backend_apply = */ nullptr,
|
|
2185
|
+
/* .backend_set_input = */ nullptr,
|
|
1288
2186
|
};
|
|
1289
2187
|
|
|
1290
2188
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
|
1291
|
-
|
|
2189
|
+
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
|
2190
|
+
|
|
2191
|
+
if (is_empty) {
|
|
2192
|
+
return llama_sampler_init_empty("?xtc");
|
|
2193
|
+
}
|
|
2194
|
+
|
|
2195
|
+
const auto seed_cur = get_rng_seed(seed);
|
|
2196
|
+
|
|
1292
2197
|
return llama_sampler_init(
|
|
1293
2198
|
/* .iface = */ &llama_sampler_xtc_i,
|
|
1294
2199
|
/* .ctx = */ new llama_sampler_xtc {
|
|
@@ -1387,16 +2292,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|
|
1387
2292
|
}
|
|
1388
2293
|
|
|
1389
2294
|
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|
1390
|
-
/* .name
|
|
1391
|
-
/* .accept
|
|
1392
|
-
/* .apply
|
|
1393
|
-
/* .reset
|
|
1394
|
-
/* .clone
|
|
1395
|
-
/* .free
|
|
2295
|
+
/* .name = */ llama_sampler_mirostat_name,
|
|
2296
|
+
/* .accept = */ nullptr,
|
|
2297
|
+
/* .apply = */ llama_sampler_mirostat_apply,
|
|
2298
|
+
/* .reset = */ llama_sampler_mirostat_reset,
|
|
2299
|
+
/* .clone = */ llama_sampler_mirostat_clone,
|
|
2300
|
+
/* .free = */ llama_sampler_mirostat_free,
|
|
2301
|
+
/* .backend_init = */ nullptr,
|
|
2302
|
+
/* .backend_accept = */ nullptr,
|
|
2303
|
+
/* .backend_apply = */ nullptr,
|
|
2304
|
+
/* .backend_set_input = */ nullptr,
|
|
1396
2305
|
};
|
|
1397
2306
|
|
|
1398
2307
|
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
|
1399
|
-
auto seed_cur = get_rng_seed(seed);
|
|
2308
|
+
const auto seed_cur = get_rng_seed(seed);
|
|
2309
|
+
|
|
1400
2310
|
return llama_sampler_init(
|
|
1401
2311
|
/* .iface = */ &llama_sampler_mirostat_i,
|
|
1402
2312
|
/* .ctx = */ new llama_sampler_mirostat {
|
|
@@ -1486,12 +2396,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
|
|
1486
2396
|
}
|
|
1487
2397
|
|
|
1488
2398
|
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|
1489
|
-
/* .name
|
|
1490
|
-
/* .accept
|
|
1491
|
-
/* .apply
|
|
1492
|
-
/* .reset
|
|
1493
|
-
/* .clone
|
|
1494
|
-
/* .free
|
|
2399
|
+
/* .name = */ llama_sampler_mirostat_v2_name,
|
|
2400
|
+
/* .accept = */ nullptr,
|
|
2401
|
+
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
|
2402
|
+
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
|
2403
|
+
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
|
2404
|
+
/* .free = */ llama_sampler_mirostat_v2_free,
|
|
2405
|
+
/* .backend_init = */ nullptr,
|
|
2406
|
+
/* .backend_accept = */ nullptr,
|
|
2407
|
+
/* .backend_apply = */ nullptr,
|
|
2408
|
+
/* .backend_set_input = */ nullptr,
|
|
1495
2409
|
};
|
|
1496
2410
|
|
|
1497
2411
|
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
|
@@ -1603,12 +2517,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
|
|
1603
2517
|
}
|
|
1604
2518
|
|
|
1605
2519
|
static struct llama_sampler_i llama_sampler_grammar_i = {
|
|
1606
|
-
/* .name
|
|
1607
|
-
/* .accept
|
|
1608
|
-
/* .apply
|
|
1609
|
-
/* .reset
|
|
1610
|
-
/* .clone
|
|
1611
|
-
/* .free
|
|
2520
|
+
/* .name = */ llama_sampler_grammar_name,
|
|
2521
|
+
/* .accept = */ llama_sampler_grammar_accept_impl,
|
|
2522
|
+
/* .apply = */ llama_sampler_grammar_apply,
|
|
2523
|
+
/* .reset = */ llama_sampler_grammar_reset,
|
|
2524
|
+
/* .clone = */ llama_sampler_grammar_clone,
|
|
2525
|
+
/* .free = */ llama_sampler_grammar_free,
|
|
2526
|
+
/* .backend_init = */ nullptr,
|
|
2527
|
+
/* .backend_accept = */ nullptr,
|
|
2528
|
+
/* .backend_apply = */ nullptr,
|
|
2529
|
+
/* .backend_set_input = */ nullptr,
|
|
1612
2530
|
};
|
|
1613
2531
|
|
|
1614
2532
|
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
@@ -1625,10 +2543,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
|
1625
2543
|
auto * ctx = new llama_sampler_grammar;
|
|
1626
2544
|
|
|
1627
2545
|
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
|
2546
|
+
std::string trigger_pattern;
|
|
2547
|
+
llama_grammar * grammar = nullptr;
|
|
1628
2548
|
// TODO: remove trigger_words support.
|
|
1629
2549
|
if (trigger_words != nullptr && num_trigger_words > 0) {
|
|
1630
2550
|
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
|
1631
|
-
|
|
2551
|
+
trigger_pattern = "[\\s\\S]*?(";
|
|
1632
2552
|
for (size_t i = 0; i < num_trigger_words; ++i) {
|
|
1633
2553
|
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
|
1634
2554
|
if (i > 0) {
|
|
@@ -1637,15 +2557,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
|
1637
2557
|
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
|
1638
2558
|
}
|
|
1639
2559
|
trigger_pattern += ")[\\s\\S]*";
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
2560
|
+
|
|
2561
|
+
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
|
|
2562
|
+
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
|
|
2563
|
+
} else {
|
|
2564
|
+
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
|
|
1643
2565
|
}
|
|
1644
2566
|
*ctx = {
|
|
1645
2567
|
/* .vocab = */ vocab,
|
|
1646
2568
|
/* .grammar_str = */ grammar_str,
|
|
1647
2569
|
/* .grammar_root = */ grammar_root,
|
|
1648
|
-
/* .grammar = */
|
|
2570
|
+
/* .grammar = */ grammar,
|
|
1649
2571
|
};
|
|
1650
2572
|
if (!ctx->grammar) {
|
|
1651
2573
|
delete ctx;
|
|
@@ -1806,12 +2728,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
|
|
1806
2728
|
}
|
|
1807
2729
|
|
|
1808
2730
|
static struct llama_sampler_i llama_sampler_penalties_i = {
|
|
1809
|
-
/* .name
|
|
1810
|
-
/* .accept
|
|
1811
|
-
/* .apply
|
|
1812
|
-
/* .reset
|
|
1813
|
-
/* .clone
|
|
1814
|
-
/* .free
|
|
2731
|
+
/* .name = */ llama_sampler_penalties_name,
|
|
2732
|
+
/* .accept = */ llama_sampler_penalties_accept,
|
|
2733
|
+
/* .apply = */ llama_sampler_penalties_apply,
|
|
2734
|
+
/* .reset = */ llama_sampler_penalties_reset,
|
|
2735
|
+
/* .clone = */ llama_sampler_penalties_clone,
|
|
2736
|
+
/* .free = */ llama_sampler_penalties_free,
|
|
2737
|
+
/* .backend_init = */ nullptr,
|
|
2738
|
+
/* .backend_accept = */ nullptr,
|
|
2739
|
+
/* .backend_apply = */ nullptr,
|
|
2740
|
+
/* .backend_set_input = */ nullptr,
|
|
1815
2741
|
};
|
|
1816
2742
|
|
|
1817
2743
|
struct llama_sampler * llama_sampler_init_penalties(
|
|
@@ -1821,6 +2747,12 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
|
1821
2747
|
float penalty_present) {
|
|
1822
2748
|
penalty_last_n = std::max(penalty_last_n, 0);
|
|
1823
2749
|
|
|
2750
|
+
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
|
2751
|
+
|
|
2752
|
+
if (is_empty) {
|
|
2753
|
+
return llama_sampler_init_empty("?penalties");
|
|
2754
|
+
}
|
|
2755
|
+
|
|
1824
2756
|
return llama_sampler_init(
|
|
1825
2757
|
/* .iface = */ &llama_sampler_penalties_i,
|
|
1826
2758
|
/* .ctx = */ new llama_sampler_penalties {
|
|
@@ -1858,9 +2790,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
|
|
1858
2790
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
1859
2791
|
// Only count non-negative infinity values
|
|
1860
2792
|
if (cur_p->data[i].logit != -INFINITY) {
|
|
1861
|
-
|
|
1862
|
-
max = cur_p->data[i].logit;
|
|
1863
|
-
}
|
|
2793
|
+
max = std::max(max, cur_p->data[i].logit);
|
|
1864
2794
|
logits_sum += cur_p->data[i].logit;
|
|
1865
2795
|
valid_count++;
|
|
1866
2796
|
}
|
|
@@ -1897,15 +2827,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
|
|
1897
2827
|
}
|
|
1898
2828
|
|
|
1899
2829
|
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
|
1900
|
-
/* .name
|
|
1901
|
-
/* .accept
|
|
1902
|
-
/* .apply
|
|
1903
|
-
/* .reset
|
|
1904
|
-
/* .clone
|
|
1905
|
-
/* .free
|
|
2830
|
+
/* .name = */ llama_sampler_top_n_sigma_name,
|
|
2831
|
+
/* .accept = */ nullptr,
|
|
2832
|
+
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
|
2833
|
+
/* .reset = */ nullptr,
|
|
2834
|
+
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
|
2835
|
+
/* .free = */ llama_sampler_top_n_sigma_free,
|
|
2836
|
+
/* .backend_init = */ nullptr,
|
|
2837
|
+
/* .backend_accept = */ nullptr,
|
|
2838
|
+
/* .backend_apply = */ nullptr,
|
|
2839
|
+
/* .backend_set_input = */ nullptr,
|
|
1906
2840
|
};
|
|
1907
2841
|
|
|
1908
2842
|
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
|
2843
|
+
const bool is_empty = (n <= 0.0f);
|
|
2844
|
+
|
|
2845
|
+
if (is_empty) {
|
|
2846
|
+
return llama_sampler_init_empty("?top-n-sigma");
|
|
2847
|
+
}
|
|
2848
|
+
|
|
1909
2849
|
return llama_sampler_init(
|
|
1910
2850
|
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
|
1911
2851
|
/* .ctx = */ new llama_sampler_top_n_sigma {
|
|
@@ -2227,12 +3167,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
|
|
2227
3167
|
}
|
|
2228
3168
|
|
|
2229
3169
|
static struct llama_sampler_i llama_sampler_dry_i = {
|
|
2230
|
-
/* .name
|
|
2231
|
-
/* .accept
|
|
2232
|
-
/* .apply
|
|
2233
|
-
/* .reset
|
|
2234
|
-
/* .clone
|
|
2235
|
-
/* .free
|
|
3170
|
+
/* .name = */ llama_sampler_dry_name,
|
|
3171
|
+
/* .accept = */ llama_sampler_dry_accept,
|
|
3172
|
+
/* .apply = */ llama_sampler_dry_apply,
|
|
3173
|
+
/* .reset = */ llama_sampler_dry_reset,
|
|
3174
|
+
/* .clone = */ llama_sampler_dry_clone,
|
|
3175
|
+
/* .free = */ llama_sampler_dry_free,
|
|
3176
|
+
/* .backend_init = */ nullptr,
|
|
3177
|
+
/* .backend_accept = */ nullptr,
|
|
3178
|
+
/* .backend_apply = */ nullptr,
|
|
3179
|
+
/* .backend_set_input = */ nullptr,
|
|
2236
3180
|
};
|
|
2237
3181
|
|
|
2238
3182
|
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
|
@@ -2243,6 +3187,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
|
|
2243
3187
|
|
|
2244
3188
|
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
|
2245
3189
|
|
|
3190
|
+
if (!dry_enabled) {
|
|
3191
|
+
return llama_sampler_init_empty("?dry");
|
|
3192
|
+
}
|
|
3193
|
+
|
|
2246
3194
|
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
|
2247
3195
|
// Process sequence breakers
|
|
2248
3196
|
for (size_t i = 0; i < num_breakers; ++i) {
|
|
@@ -2311,18 +3259,186 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
|
|
2311
3259
|
return result;
|
|
2312
3260
|
}
|
|
2313
3261
|
|
|
3262
|
+
// adaptive-p sampler state
|
|
3263
|
+
//
|
|
3264
|
+
// maintains an exponential moving average of the *ORIGINAL* probabilities
|
|
3265
|
+
// of selected tokens, used to compute an adapted target at each sampling step.
|
|
3266
|
+
//
|
|
3267
|
+
// see llama.h for a full description of the sampler
|
|
3268
|
+
//
|
|
3269
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/17927
|
|
3270
|
+
//
|
|
3271
|
+
struct llama_sampler_adaptive_p {
|
|
3272
|
+
const float target; // target probability (0.0 - 1.0; negative = disabled)
|
|
3273
|
+
const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
|
|
3274
|
+
const uint32_t seed; // original RNG seed
|
|
3275
|
+
uint32_t seed_cur; // actual RNG seed
|
|
3276
|
+
std::mt19937 rng; // RNG state
|
|
3277
|
+
float weighted_sum; // sum(p_i * decay^i)
|
|
3278
|
+
float total_weight; // sum(decay^i), converges to 1/(1-decay)
|
|
3279
|
+
std::vector<float> original_probs; // pre-transform probs, cached for EMA update
|
|
3280
|
+
llama_token pending_token_id; // token ID of selected token
|
|
3281
|
+
int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs
|
|
3282
|
+
};
|
|
3283
|
+
|
|
3284
|
+
// adaptive probability transformation constants
|
|
3285
|
+
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
|
|
3286
|
+
static constexpr float PEAK_LOGIT_VALUE = 5.0f;
|
|
3287
|
+
static constexpr float SHARPNESS = 10.0f;
|
|
3288
|
+
static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
|
|
3289
|
+
|
|
3290
|
+
static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
|
|
3291
|
+
return "adaptive-p";
|
|
3292
|
+
}
|
|
3293
|
+
|
|
3294
|
+
static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
3295
|
+
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
|
|
3296
|
+
|
|
3297
|
+
llama_sampler_softmax_impl(cur_p, false);
|
|
3298
|
+
|
|
3299
|
+
if (ctx->target < 0.0f) {
|
|
3300
|
+
// at negative target values, adaptive-p is no-op
|
|
3301
|
+
// we simply sample from the existing distribution
|
|
3302
|
+
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
|
3303
|
+
return;
|
|
3304
|
+
}
|
|
3305
|
+
|
|
3306
|
+
// store the original probabilities
|
|
3307
|
+
ctx->original_probs.resize(cur_p->size);
|
|
3308
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
3309
|
+
ctx->original_probs[i] = cur_p->data[i].p;
|
|
3310
|
+
}
|
|
3311
|
+
|
|
3312
|
+
// using the EMA, compute the adapted target probability for the current sampling step
|
|
3313
|
+
auto target = std::clamp(ctx->target, 0.0f, 1.0f);
|
|
3314
|
+
float adapted_target = std::clamp(
|
|
3315
|
+
ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
|
|
3316
|
+
0.0f, 1.0f
|
|
3317
|
+
);
|
|
3318
|
+
|
|
3319
|
+
// adaptive probability transform
|
|
3320
|
+
//
|
|
3321
|
+
// quadratic near target for fine differentiation, transitioning to linear decay in the
|
|
3322
|
+
// tails. unbounded negative logits ensure proper suppression of far-from-target tokens
|
|
3323
|
+
// after the softmax.
|
|
3324
|
+
//
|
|
3325
|
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
3326
|
+
if (cur_p->data[i].logit == -INFINITY) {
|
|
3327
|
+
// don't transform logits that are -INFINITY
|
|
3328
|
+
// (as masked out by e.g. min-p and top-p when using backend sampling)
|
|
3329
|
+
continue;
|
|
3330
|
+
}
|
|
3331
|
+
float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
|
|
3332
|
+
cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
|
|
3333
|
+
}
|
|
3334
|
+
|
|
3335
|
+
// softmax and sample from the transformed distribution
|
|
3336
|
+
llama_sampler_softmax_impl(cur_p, false);
|
|
3337
|
+
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
|
3338
|
+
cur_p->selected = idx;
|
|
3339
|
+
|
|
3340
|
+
// store the selected token ID for acceptance later
|
|
3341
|
+
ctx->pending_token_id = cur_p->data[idx].id;
|
|
3342
|
+
ctx->pending_token_idx = idx;
|
|
3343
|
+
}
|
|
3344
|
+
|
|
3345
|
+
static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) {
|
|
3346
|
+
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
|
|
3347
|
+
if (ctx->pending_token_id == token) {
|
|
3348
|
+
GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL);
|
|
3349
|
+
GGML_ASSERT(ctx->pending_token_idx != -1);
|
|
3350
|
+
// update EMA with the original probability of the selected token
|
|
3351
|
+
ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum;
|
|
3352
|
+
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
|
|
3353
|
+
}
|
|
3354
|
+
ctx->pending_token_id = LLAMA_TOKEN_NULL;
|
|
3355
|
+
ctx->pending_token_idx = -1;
|
|
3356
|
+
}
|
|
3357
|
+
|
|
3358
|
+
static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
|
|
3359
|
+
auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
|
|
3360
|
+
// ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
|
|
3361
|
+
// original_probs is completely overwritten on every call to _apply.
|
|
3362
|
+
// so we only need to reset the EMA state and pending token.
|
|
3363
|
+
ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
|
|
3364
|
+
ctx->total_weight = 1.0f / (1.0f - ctx->decay);
|
|
3365
|
+
ctx->pending_token_id = LLAMA_TOKEN_NULL;
|
|
3366
|
+
ctx->pending_token_idx = -1;
|
|
3367
|
+
ctx->seed_cur = get_rng_seed(ctx->seed);
|
|
3368
|
+
ctx->rng.seed(ctx->seed_cur);
|
|
3369
|
+
}
|
|
3370
|
+
|
|
3371
|
+
static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
|
|
3372
|
+
const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;
|
|
3373
|
+
auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
|
|
3374
|
+
auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
|
|
3375
|
+
|
|
3376
|
+
// copy everything (target, decay, seed, and RNG are already set)
|
|
3377
|
+
result_ctx->weighted_sum = ctx->weighted_sum;
|
|
3378
|
+
result_ctx->total_weight = ctx->total_weight;
|
|
3379
|
+
result_ctx->pending_token_id = ctx->pending_token_id;
|
|
3380
|
+
result_ctx->pending_token_idx = ctx->pending_token_idx;
|
|
3381
|
+
|
|
3382
|
+
return result;
|
|
3383
|
+
}
|
|
3384
|
+
|
|
3385
|
+
static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
|
|
3386
|
+
delete (llama_sampler_adaptive_p *) smpl->ctx;
|
|
3387
|
+
}
|
|
3388
|
+
|
|
3389
|
+
static struct llama_sampler_i llama_sampler_adaptive_p_i = {
|
|
3390
|
+
/* .name = */ llama_sampler_adaptive_p_name,
|
|
3391
|
+
/* .accept = */ llama_sampler_adaptive_p_accept,
|
|
3392
|
+
/* .apply = */ llama_sampler_adaptive_p_apply,
|
|
3393
|
+
/* .reset = */ llama_sampler_adaptive_p_reset,
|
|
3394
|
+
/* .clone = */ llama_sampler_adaptive_p_clone,
|
|
3395
|
+
/* .free = */ llama_sampler_adaptive_p_free,
|
|
3396
|
+
/* .backend_init = */ nullptr,
|
|
3397
|
+
/* .backend_accept = */ nullptr,
|
|
3398
|
+
/* .backend_apply = */ nullptr,
|
|
3399
|
+
/* .backend_set_input = */ nullptr,
|
|
3400
|
+
};
|
|
3401
|
+
|
|
3402
|
+
struct llama_sampler * llama_sampler_init_adaptive_p(
|
|
3403
|
+
float target,
|
|
3404
|
+
float decay,
|
|
3405
|
+
uint32_t seed
|
|
3406
|
+
) {
|
|
3407
|
+
auto seed_cur = get_rng_seed(seed);
|
|
3408
|
+
float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
|
3409
|
+
return llama_sampler_init(
|
|
3410
|
+
/* .iface = */ &llama_sampler_adaptive_p_i,
|
|
3411
|
+
/* .ctx = */ new llama_sampler_adaptive_p {
|
|
3412
|
+
/* .target = */ target,
|
|
3413
|
+
/* .decay = */ clamped_decay,
|
|
3414
|
+
/* .seed = */ seed,
|
|
3415
|
+
/* .seed_cur = */ seed_cur,
|
|
3416
|
+
/* .rng = */ std::mt19937(seed_cur),
|
|
3417
|
+
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
|
3418
|
+
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
|
3419
|
+
/* .original_probs = */ {},
|
|
3420
|
+
/* .pending_token_id = */ LLAMA_TOKEN_NULL,
|
|
3421
|
+
/* .pending_token_idx = */ -1
|
|
3422
|
+
}
|
|
3423
|
+
);
|
|
3424
|
+
}
|
|
3425
|
+
|
|
2314
3426
|
// logit-bias
|
|
2315
3427
|
|
|
2316
|
-
struct llama_sampler_logit_bias {
|
|
3428
|
+
struct llama_sampler_logit_bias : public llama_sampler_backend {
|
|
2317
3429
|
const int32_t n_vocab;
|
|
2318
3430
|
|
|
2319
3431
|
const std::vector<llama_logit_bias> logit_bias;
|
|
2320
3432
|
|
|
2321
3433
|
std::vector<llama_logit_bias> to_search;
|
|
3434
|
+
|
|
3435
|
+
struct ggml_tensor * inp_logit_bias;
|
|
3436
|
+
struct ggml_tensor * inp_logit_idxs;
|
|
2322
3437
|
};
|
|
2323
3438
|
|
|
2324
|
-
static const char * llama_sampler_logit_bias_name(const struct llama_sampler *
|
|
2325
|
-
|
|
3439
|
+
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
|
|
3440
|
+
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3441
|
+
return ctx->get_name();
|
|
2326
3442
|
}
|
|
2327
3443
|
|
|
2328
3444
|
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
@@ -2367,25 +3483,110 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
|
|
2367
3483
|
delete (llama_sampler_logit_bias *) smpl->ctx;
|
|
2368
3484
|
}
|
|
2369
3485
|
|
|
3486
|
+
static void llama_sampler_logit_bias_backend_apply(
|
|
3487
|
+
struct llama_sampler * smpl,
|
|
3488
|
+
struct ggml_context * ctx,
|
|
3489
|
+
struct ggml_cgraph * gf,
|
|
3490
|
+
struct llama_sampler_data * data) {
|
|
3491
|
+
GGML_UNUSED(gf);
|
|
3492
|
+
GGML_UNUSED(ctx);
|
|
3493
|
+
|
|
3494
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3495
|
+
if (sctx->logit_bias.empty()) {
|
|
3496
|
+
return;
|
|
3497
|
+
}
|
|
3498
|
+
|
|
3499
|
+
const size_t n = sctx->logit_bias.size();
|
|
3500
|
+
|
|
3501
|
+
sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
|
|
3502
|
+
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
|
|
3503
|
+
ggml_set_input(sctx->inp_logit_bias);
|
|
3504
|
+
|
|
3505
|
+
sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
|
|
3506
|
+
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
|
|
3507
|
+
ggml_set_input(sctx->inp_logit_idxs);
|
|
3508
|
+
|
|
3509
|
+
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
|
|
3510
|
+
|
|
3511
|
+
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
|
|
3512
|
+
cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
|
|
3513
|
+
cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
|
|
3514
|
+
|
|
3515
|
+
data->logits = ggml_add(ctx, data->logits, cur);
|
|
3516
|
+
}
|
|
3517
|
+
|
|
3518
|
+
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
|
|
3519
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3520
|
+
if (sctx->logit_bias.empty()) {
|
|
3521
|
+
return;
|
|
3522
|
+
}
|
|
3523
|
+
|
|
3524
|
+
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
|
3525
|
+
GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
|
|
3526
|
+
|
|
3527
|
+
const size_t n = sctx->logit_bias.size();
|
|
3528
|
+
|
|
3529
|
+
std::vector<float> data_logit_bias(n, 0.0f);
|
|
3530
|
+
std::vector<int32_t> data_logit_idxs(n, 0);
|
|
3531
|
+
for (size_t i = 0; i < n; ++i) {
|
|
3532
|
+
const auto & lb = sctx->logit_bias[i];
|
|
3533
|
+
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
|
3534
|
+
data_logit_bias[i] = lb.bias;
|
|
3535
|
+
data_logit_idxs[i] = lb.token;
|
|
3536
|
+
}
|
|
3537
|
+
|
|
3538
|
+
ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
|
3539
|
+
ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
|
|
3540
|
+
}
|
|
3541
|
+
|
|
3542
|
+
static bool llama_sampler_logit_bias_backend_init(
|
|
3543
|
+
struct llama_sampler * smpl,
|
|
3544
|
+
ggml_backend_buffer_type_t buft) {
|
|
3545
|
+
GGML_UNUSED(buft);
|
|
3546
|
+
|
|
3547
|
+
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
|
3548
|
+
|
|
3549
|
+
sctx->init(true);
|
|
3550
|
+
|
|
3551
|
+
if (sctx->logit_bias.empty()) {
|
|
3552
|
+
return true;
|
|
3553
|
+
}
|
|
3554
|
+
|
|
3555
|
+
return true;
|
|
3556
|
+
}
|
|
3557
|
+
|
|
2370
3558
|
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
|
2371
|
-
/* .name
|
|
2372
|
-
/* .accept
|
|
2373
|
-
/* .apply
|
|
2374
|
-
/* .reset
|
|
2375
|
-
/* .clone
|
|
2376
|
-
/* .free
|
|
3559
|
+
/* .name = */ llama_sampler_logit_bias_name,
|
|
3560
|
+
/* .accept = */ nullptr,
|
|
3561
|
+
/* .apply = */ llama_sampler_logit_bias_apply,
|
|
3562
|
+
/* .reset = */ nullptr,
|
|
3563
|
+
/* .clone = */ llama_sampler_logit_bias_clone,
|
|
3564
|
+
/* .free = */ llama_sampler_logit_bias_free,
|
|
3565
|
+
/* .backend_init = */ llama_sampler_logit_bias_backend_init,
|
|
3566
|
+
/* .backend_accept = */ nullptr,
|
|
3567
|
+
/* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
|
|
3568
|
+
/* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
|
|
2377
3569
|
};
|
|
2378
3570
|
|
|
2379
3571
|
struct llama_sampler * llama_sampler_init_logit_bias(
|
|
2380
3572
|
int32_t n_vocab,
|
|
2381
3573
|
int32_t n_logit_bias,
|
|
2382
3574
|
const llama_logit_bias * logit_bias) {
|
|
3575
|
+
const bool is_empty = n_logit_bias <= 0;
|
|
3576
|
+
|
|
3577
|
+
if (is_empty) {
|
|
3578
|
+
return llama_sampler_init_empty("?logit-bias");
|
|
3579
|
+
}
|
|
3580
|
+
|
|
2383
3581
|
return llama_sampler_init(
|
|
2384
3582
|
/* .iface = */ &llama_sampler_logit_bias_i,
|
|
2385
3583
|
/* .ctx = */ new llama_sampler_logit_bias {
|
|
2386
|
-
|
|
2387
|
-
/* .
|
|
2388
|
-
/* .
|
|
3584
|
+
("logit-bias"),
|
|
3585
|
+
/* .n_vocab = */ n_vocab,
|
|
3586
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
|
3587
|
+
/* .to_search = */ {},
|
|
3588
|
+
/* .inp_logit_bias = */ nullptr,
|
|
3589
|
+
/* .inp_logit_idxs = */ nullptr,
|
|
2389
3590
|
}
|
|
2390
3591
|
);
|
|
2391
3592
|
}
|
|
@@ -2541,8 +3742,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
|
|
2541
3742
|
if (n_non_eog == 0) {
|
|
2542
3743
|
cur_p->size = 1;
|
|
2543
3744
|
cur_p->data[0].id = ctx->vocab->token_eot();
|
|
3745
|
+
if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
|
|
3746
|
+
cur_p->data[0].id = ctx->vocab->token_eos();
|
|
3747
|
+
}
|
|
2544
3748
|
cur_p->data[0].logit = 1.0f;
|
|
2545
3749
|
|
|
3750
|
+
GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
|
|
3751
|
+
|
|
2546
3752
|
return;
|
|
2547
3753
|
}
|
|
2548
3754
|
|
|
@@ -2593,12 +3799,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
|
|
2593
3799
|
}
|
|
2594
3800
|
|
|
2595
3801
|
static struct llama_sampler_i llama_sampler_infill_i = {
|
|
2596
|
-
/* .name
|
|
2597
|
-
/* .accept
|
|
2598
|
-
/* .apply
|
|
2599
|
-
/* .reset
|
|
2600
|
-
/* .clone
|
|
2601
|
-
/* .free
|
|
3802
|
+
/* .name = */ llama_sampler_infill_name,
|
|
3803
|
+
/* .accept = */ nullptr,
|
|
3804
|
+
/* .apply = */ llama_sampler_infill_apply,
|
|
3805
|
+
/* .reset = */ nullptr,
|
|
3806
|
+
/* .clone = */ llama_sampler_infill_clone,
|
|
3807
|
+
/* .free = */ llama_sampler_infill_free,
|
|
3808
|
+
/* .backend_apply = */ nullptr,
|
|
3809
|
+
/* .backend_accept = */ nullptr,
|
|
3810
|
+
/* .backend_set_input = */ nullptr,
|
|
3811
|
+
/* .backend_init = */ nullptr,
|
|
2602
3812
|
};
|
|
2603
3813
|
|
|
2604
3814
|
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
|
@@ -2630,7 +3840,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
|
2630
3840
|
if (smpl->iface == &llama_sampler_chain_i) {
|
|
2631
3841
|
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
|
2632
3842
|
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
|
2633
|
-
const uint32_t seed = llama_sampler_get_seed(
|
|
3843
|
+
const uint32_t seed = llama_sampler_get_seed(it->ptr);
|
|
2634
3844
|
if (seed != LLAMA_DEFAULT_SEED) {
|
|
2635
3845
|
return seed;
|
|
2636
3846
|
}
|
|
@@ -2660,8 +3870,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
|
|
|
2660
3870
|
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
|
2661
3871
|
const auto data = llama_perf_sampler(chain);
|
|
2662
3872
|
|
|
2663
|
-
LLAMA_LOG_INFO("%s:
|
|
2664
|
-
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
|
3873
|
+
LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
|
|
2665
3874
|
}
|
|
2666
3875
|
|
|
2667
3876
|
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
@@ -2671,5 +3880,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
|
2671
3880
|
|
|
2672
3881
|
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
|
2673
3882
|
|
|
2674
|
-
ctx->t_sample_us =
|
|
3883
|
+
ctx->t_sample_us = 0;
|
|
3884
|
+
ctx->n_sample = 0;
|
|
2675
3885
|
}
|