whispercpp 1.3.3 → 1.3.5
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -4,6 +4,9 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"

+#include "ggml-cpp.h"
+
+#include <array>
 #include <algorithm>
 #include <cassert>
 #include <cfloat>
@@ -128,6 +131,89 @@ struct ring_buffer {
     std::vector<T> data;
 };

+// writes result in res, does not mutate cur
+static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    constexpr int   nbuckets     = 128;
+    constexpr float bucket_low   = -10.0f;
+    constexpr float bucket_high  =  10.0f;
+    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+    constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+    std::vector<int> bucket_idx;
+    std::vector<int> histo(nbuckets, 0);
+
+    std::vector<llama_token_data*> bucket_ptrs;
+
+    bucket_idx.reserve(cur.size);
+
+    for (int i = 0; i < (int)cur.size; ++i) {
+        const float val = cur.data[i].logit;
+        int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+        ib = std::max(0, std::min(nbuckets - 1, ib));
+        bucket_idx.push_back(ib);
+        ++histo[ib];
+    }
+    int nhave = 0;
+    int ib = nbuckets - 1;
+    for ( ; ib >= 0; --ib) {
+        nhave += histo[ib];
+        if (nhave >= npartial) {
+            break;
+        }
+    }
+    res.resize(nhave);
+    auto * ptr = res.data();
+    bucket_ptrs.reserve(nbuckets - ib);
+    for (int j = nbuckets - 1; j >= ib; --j) {
+        bucket_ptrs.push_back(ptr);
+        ptr += histo[j];
+    }
+    for (int i = 0; i < (int)cur.size; ++i) {
+        int j = bucket_idx[i];
+        if (j >= ib) {
+            *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
+        }
+    }
+
+    ptr = res.data();
+    int ndone = 0;
+    for (int j = nbuckets - 1; j > ib; --j) {
+        std::sort(ptr, ptr + histo[j], comp);
+        ptr += histo[j];
+        ndone += histo[j];
+    }
+    std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
+}
+
+// reduces the size of cur_p to npartial, keeping only the top npartial elements
+static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    if (npartial <= 128) {
+        std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
+
+        cur_p->size = npartial;
+        cur_p->sorted = true;
+
+        return;
+    }
+
+    std::vector<llama_token_data> tmp;
+
+    llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
+
+    std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
+
+    cur_p->size = npartial;
+    cur_p->sorted = true;
+}
+
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
 #ifdef __GNUC__
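The bucket pass above is the heart of the new helper: logits are histogrammed into 128 fixed-width buckets over [-10, 10], only the top buckets that can contain the best `npartial` entries are visited, and a full `std::sort` runs only inside those buckets. A minimal standalone sketch of the same idea on plain floats (the function name and toy `main` are illustrative, not part of the diff):

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // top `k` values of `vals` in descending order, bucket-then-sort style
    static std::vector<float> top_k_bucketed(const std::vector<float> & vals, int k) {
        constexpr int   nbuckets = 128;
        constexpr float lo = -10.0f, hi = 10.0f;
        constexpr float scale = nbuckets / (hi - lo);

        std::vector<std::vector<float>> buckets(nbuckets);
        for (float v : vals) {
            const int ib = std::max(0, std::min(nbuckets - 1, int((v - lo) * scale)));
            buckets[ib].push_back(v);
        }

        // walk buckets from the highest downwards until k candidates are collected
        std::vector<float> res;
        for (int ib = nbuckets - 1; ib >= 0 && (int) res.size() < k; --ib) {
            std::sort(buckets[ib].begin(), buckets[ib].end(), std::greater<float>());
            res.insert(res.end(), buckets[ib].begin(), buckets[ib].end());
        }
        res.resize(std::min(res.size(), (size_t) k));
        return res;
    }

    int main() {
        const std::vector<float> logits = { 1.5f, -2.0f, 3.2f, 0.1f, 2.9f };
        for (float v : top_k_bucketed(logits, 3)) {
            printf("%.1f\n", v); // 3.2, 2.9, 1.5
        }
    }

Only the buckets above the cutoff are ever sorted, which is why this beats a plain `std::partial_sort` once the candidate array is large.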
@@ -200,18 +286,21 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp)
     }
 }

-static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
     GGML_ASSERT(cur_p->size > 0);

-    // Sort the logits in descending order
-    if (!cur_p->sorted) {
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        cur_p->sorted = true;
+    // Sort the logits in descending order if requested
+    if (do_sort && !cur_p->sorted) {
+        llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
     }

     float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
     float cum_sum = 0.0f;

     for (size_t i = 0; i < cur_p->size; ++i) {
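Since sorting is now optional, the maximum logit has to be found with an explicit scan before exponentiating; subtracting it is what keeps the softmax finite. A small self-contained illustration of the shift (toy values, not from the diff):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = { 1000.0f, 999.0f, 998.0f };

        // without the shift, expf(1000.0f) overflows to inf;
        // with it, the largest term is expf(0.0f) == 1
        float max_l = logits[0];
        for (float l : logits) max_l = std::max(max_l, l);

        float cum_sum = 0.0f;
        std::vector<float> p(logits.size());
        for (size_t i = 0; i < logits.size(); ++i) {
            p[i] = expf(logits[i] - max_l);
            cum_sum += p[i];
        }
        for (size_t i = 0; i < p.size(); ++i) {
            printf("p[%zu] = %.4f\n", i, p[i] / cum_sum); // 0.6652, 0.2447, 0.0900
        }
    }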
@@ -226,7 +315,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
 }

 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
@@ -239,64 +327,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)

     // Sort scores in descending order
     if (!cur_p->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucket_inter = -bucket_low * bucket_scale;
-
-            std::vector<int> bucket_idx(cur_p->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                const float val = cur_p->data[i].logit;
-                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets - 1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) {
-                    break;
-                }
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto * ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets - 1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        cur_p->sorted = true;
+        llama_token_data_array_partial_sort_inplace(cur_p, k);
     }

     cur_p->size = k;
@@ -317,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {

 // llama_sampler API

-struct llama_sampler * llama_sampler_init(
+struct llama_sampler * llama_sampler_init(
+        struct llama_sampler_i  * iface,
+           llama_sampler_context_t ctx) {
     return new llama_sampler {
         /* .iface = */ iface,
         /* .ctx   = */ ctx,
@@ -333,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
 }

 void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->accept) {
         smpl->iface->accept(smpl, token);
     }
 }

 void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    if (!smpl) {
+        return;
+    }
+
     GGML_ASSERT(smpl->iface->apply);
     smpl->iface->apply(smpl, cur_p);
 }

 void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->reset) {
         smpl->iface->reset(smpl);
     }
 }

 struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (!smpl) {
+        return nullptr;
+    }
+
     if (smpl->iface->clone) {
         return smpl->iface->clone(smpl);
     }
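With these guards the public entry points tolerate a null sampler, so a failed construction upstream degrades to a no-op instead of a crash. A short illustrative fragment (hypothetical caller, not from the diff):

    llama_sampler * smpl = nullptr;                  // e.g. construction failed upstream
    llama_sampler_accept(smpl, 42);                  // no-op
    llama_sampler_reset (smpl);                      // no-op
    llama_sampler * dup = llama_sampler_clone(smpl); // dup == nullptr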
@@ -376,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }

-    const auto * logits = llama_get_logits_ith(ctx, idx);
+// empty sampler

-    const
+struct llama_sampler_empty {
+    const char * name;
+};

+static struct llama_sampler * llama_sampler_init_empty(const char * name);
+
+static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_empty *) smpl->ctx;
+    return ctx->name;
+}
+
+static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(token);
+}
+
+static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(cur_p);
+}
+
+static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+
+static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_empty *) smpl->ctx;
+    return llama_sampler_init_empty(ctx->name);
+}
+
+static void llama_sampler_empty_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_empty *) smpl->ctx;
+}
+
+static bool llama_sampler_empty_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(buft);
+
+    return true;
+}
+
+static void llama_sampler_empty_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(selected_token);
+}
+
+static void llama_sampler_empty_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(data);
+}
+
+static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+
+static struct llama_sampler_i llama_sampler_empty_i = {
+    /* .name              = */ llama_sampler_empty_name,
+    /* .accept            = */ llama_sampler_empty_accept,
+    /* .apply             = */ llama_sampler_empty_apply,
+    /* .reset             = */ llama_sampler_empty_reset,
+    /* .clone             = */ llama_sampler_empty_clone,
+    /* .free              = */ llama_sampler_empty_free,
+    /* .backend_init      = */ llama_sampler_empty_backend_init,
+    /* .backend_accept    = */ llama_sampler_empty_backend_accept,
+    /* .backend_apply     = */ llama_sampler_empty_backend_apply,
+    /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
+};
+
+struct llama_sampler * llama_sampler_init_empty(const char * name) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_empty_i,
+        /* .ctx   = */ new llama_sampler_empty {
+            /* .name = */ name,
+        }
+    );
+}
+
+// common backend sampler functionality
+//
+// +name : means that the sampler is support and will run on the backend
+// -name : means that a ggml operator is not supported by the backend
+//
+struct llama_sampler_backend {
+    llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}

+    const char * get_name() {
+        if (!is_init) {
+            return name.c_str();
+        }
+
+        if (support) {
+            name_ext = "+" + name;
+        } else {
+            name_ext = "-" + name;
+        }
+
+        return name_ext.c_str();
     }

+    void init(bool support) {
+        GGML_ASSERT(this->is_init == false);
+
+        this->is_init = true;
+        this->support = support;
+    }
+
+private:
+    std::string name;
+    std::string name_ext;
+
+    bool is_init;
+    bool support;
+};
+
+// check if all ggml ops used by the sampler are supported by the backend
+static bool llama_sampler_backend_support(
+        llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * device = ggml_backend_buft_get_device(buft);
+    if (!device) {
+        // CPU backend always supported
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
     };

+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    if (!ctx_ptr) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }

+    ggml_context * ctx = ctx_ptr.get();

+    const int64_t n = 1024*1024;

+    llama_sampler_data data = {
+        /*.logits     = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
+        /*.probs      = */ nullptr,
+        /*.sampled    = */ nullptr,
+        /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
+    };

+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    smpl->iface->backend_apply(smpl, ctx, gf, &data);
+
+    if (data.logits) {
+        ggml_build_forward_expand(gf, data.logits);
+    }
+
+    if (data.probs) {
+        ggml_build_forward_expand(gf, data.probs);
+    }
+
+    if (data.sampled) {
+        ggml_build_forward_expand(gf, data.sampled);
+    }
+
+    if (data.candidates) {
+        ggml_build_forward_expand(gf, data.candidates);
+    }
+
+    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+        struct ggml_tensor * op = ggml_graph_node(gf, i);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
+                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
+
+            return false;
+        }
+    }
+
+    return true;
 }
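The `get_name()` logic above means a sampler's reported name doubles as a capability flag once `init()` has run: `+top-k` if every ggml op it emits is supported by the target backend, `-top-k` if not. A reduced sketch of just that naming behavior (standalone toy type, assumptions only):

    #include <cstdio>
    #include <string>

    struct sampler_backend {
        sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}

        const char * get_name() {
            if (!is_init) {
                return name.c_str(); // capability unknown before init()
            }
            name_ext = (support ? "+" : "-") + name;
            return name_ext.c_str();
        }

        void init(bool s) { is_init = true; support = s; }

    private:
        std::string name, name_ext;
        bool is_init, support;
    };

    int main() {
        sampler_backend s("top-k");
        printf("%s\n", s.get_name()); // "top-k"  (not initialized yet)
        s.init(true);
        printf("%s\n", s.get_name()); // "+top-k" (all ops supported)
    }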
@@ -420,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token

     time_meas tm(chain->t_sample_us, chain->params.no_perf);

-    for (auto * smpl : chain->samplers) {
-        llama_sampler_accept(smpl, token);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_accept(smpl.ptr, token);
     }

     chain->n_sample++;
@@ -432,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d

     time_meas tm(chain->t_sample_us, chain->params.no_perf);

-    for (auto * smpl : chain->samplers) {
-        llama_sampler_apply(smpl, cur_p);
+    bool is_backend = chain->is_init;
+
+    for (auto & smpl : chain->samplers) {
+        if (is_backend && smpl.is_backend) {
+            continue;
+        }
+
+        is_backend = false;
+
+        if (smpl.ptr->iface->apply == nullptr) {
+            continue;
+        }
+
+        llama_sampler_apply(smpl.ptr, cur_p);
     }
 }

 static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;

-    for (auto * smpl : chain->samplers) {
-        llama_sampler_reset(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_reset(smpl.ptr);
     }
-
-    chain->t_sample_us = 0;
-    chain->n_sample    = 0;
 }
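The loop above encodes the hand-off rule: samplers that already ran as part of the backend graph form a prefix of the chain, and the CPU pass resumes at the first sampler that could not be offloaded, never skipping again after that. A reduced sketch of just that control flow (the `Entry` type and names are illustrative):

    #include <cstdio>
    #include <vector>

    struct Entry {
        bool is_backend; // set by the chain's backend_init
        const char * name;
    };

    int main() {
        // suppose backend_init offloaded top-k and dist, but not a custom sampler
        std::vector<Entry> samplers = {
            { true,  "top-k"  },
            { true,  "dist"   },
            { false, "custom" },
            { true,  "greedy" }, // backend-capable, but placed after a CPU sampler
        };

        bool is_backend = true; // corresponds to chain->is_init
        for (const auto & s : samplers) {
            if (is_backend && s.is_backend) {
                continue; // already executed on the backend
            }
            is_backend = false; // once off the prefix, everything runs on CPU
            printf("CPU apply: %s\n", s.name); // prints: custom, greedy
        }
    }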
@@ -453,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl

     auto * result = llama_sampler_chain_init(chain_src->params);

-    for (auto * smpl : chain_src->samplers) {
-        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    for (const auto & smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
     }

     return result;
@@ -463,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
 static void llama_sampler_chain_free(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;

-    for (auto * smpl : chain->samplers) {
-        llama_sampler_free(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_free(smpl.ptr);
     }

     delete chain;
 }

+static bool llama_sampler_chain_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
+
+    chain->is_init = true;
+
+    bool res = true;
+
+    for (auto & smpl : chain->samplers) {
+        bool res_cur = true;
+
+        // to be able to run a sampler on the backend, it has to:
+        //   - have the .backend_init() API implemented
+        //   - return true during .backend_init()
+        if (smpl.ptr->iface->backend_init) {
+            if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
+                res_cur = false;
+            }
+        } else {
+            res_cur = false;
+        }
+
+        smpl.is_backend = res_cur;
+
+        res = res && res_cur;
+    }
+
+    return res;
+}
+
+static void llama_sampler_chain_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_accept) {
+            smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
+        }
+    }
+}
+
+static void llama_sampler_chain_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_apply) {
+            smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
+        }
+    }
+}
+
+static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_set_input) {
+            smpl.ptr->iface->backend_set_input(smpl.ptr);
+        }
+    }
+}
+
 static struct llama_sampler_i llama_sampler_chain_i = {
-    /* .name   = */ llama_sampler_chain_name,
-    /* .accept = */ llama_sampler_chain_accept,
-    /* .apply  = */ llama_sampler_chain_apply,
-    /* .reset  = */ llama_sampler_chain_reset,
-    /* .clone  = */ llama_sampler_chain_clone,
-    /* .free   = */ llama_sampler_chain_free,
+    /* .name              = */ llama_sampler_chain_name,
+    /* .accept            = */ llama_sampler_chain_accept,
+    /* .apply             = */ llama_sampler_chain_apply,
+    /* .reset             = */ llama_sampler_chain_reset,
+    /* .clone             = */ llama_sampler_chain_clone,
+    /* .free              = */ llama_sampler_chain_free,
+    /* .backend_init      = */ llama_sampler_chain_backend_init,
+    /* .backend_accept    = */ llama_sampler_chain_backend_accept,
+    /* .backend_apply     = */ llama_sampler_chain_backend_apply,
+    /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
 };
@@ -484,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
+            /* .is_init     = */ false,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }

+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const llama_token   sampled_token  = llama_get_sampled_token_ith     (ctx, idx);
+    const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
+    const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
+    const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);
+
+    // If a backend sampler has already sampled a token, return it.
+    if (sampled_token != LLAMA_TOKEN_NULL) {
+        LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
+        return sampled_token;
+    }
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data>   cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+
+    if (sampled_probs) {
+        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
+        cur.resize(sampled_probs_count);
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+        }
+    } else if (sampled_logits) {
+        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
+        cur.resize(sampled_logits_count);
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+        }
+    } else {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+        GGML_ASSERT(logits != nullptr);
+        cur.resize(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data     = */ cur.data(),
+        /* .size     = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted   = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
-    p->samplers.push_back(smpl);
+    p->samplers.push_back({
+        /* .is_backend = */ false,
+        /* .ptr        = */ smpl,
+    });
 }

-struct llama_sampler * llama_sampler_chain_get(
+struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
+    if (chain == nullptr) {
+        return nullptr;
+    }
+
+    if (chain->iface != &llama_sampler_chain_i) {
+        return nullptr;
+    }
+
+    if (i == -1) {
+        return chain;
+    }
+
     const auto * p = (const llama_sampler_chain *) chain->ctx;

     if (i < 0 || (size_t) i >= p->samplers.size()) {
         return nullptr;
     }

-    return p->samplers[i];
+    return p->samplers[i].ptr;
 }
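The reworked `llama_sampler_sample` first consumes whatever the backend graph already produced (a fully sampled token, or backend-filtered probs/logits/candidates) and only falls back to building the full-vocabulary array when nothing was sampled on-device. From the caller's side nothing changes; a minimal sketch of typical use of the public chain API (assumes `ctx` is a llama_context with logits available after llama_decode):

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));

    // uses the backend-sampled token when present, otherwise runs the CPU samplers
    const llama_token tok = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain); // also frees the samplers added to the chain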
@@ -513,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
         return nullptr;
     }

-    auto * result = p->samplers[i];
+    auto * result = p->samplers[i].ptr;
     p->samplers.erase(p->samplers.begin() + i);

     return result;
@@ -531,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {

 // greedy

-static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
-    return "greedy";
+struct llama_sampler_greedy : public llama_sampler_backend {
+};
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+    return sctx->get_name();
+}
+
+static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_greedy *) smpl->ctx;
+    GGML_UNUSED(ctx);
+}
+
+static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
+    auto * result = llama_sampler_init_greedy();
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_greedy *) result->ctx;
+
+        GGML_UNUSED(ctx);
+        GGML_UNUSED(result_ctx);
+    }
+
+    return result;
+}
+
+static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_greedy *) smpl->ctx;
 }

 static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -544,41 +969,150 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
     }
 }

+static bool llama_sampler_greedy_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_greedy_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(smpl);
+
+    struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
+    ggml_set_name(curl, "greedy_argmax");
+
+    data->sampled = curl;
+}
+
 static struct llama_sampler_i llama_sampler_greedy_i = {
-    /* .name   = */ llama_sampler_greedy_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_greedy_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
+    /* .name              = */ llama_sampler_greedy_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_greedy_apply,
+    /* .reset             = */ llama_sampler_greedy_reset,
+    /* .clone             = */ llama_sampler_greedy_clone,
+    /* .free              = */ llama_sampler_greedy_free,
+    /* .backend_init      = */ llama_sampler_greedy_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_greedy_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_greedy() {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */
+        /* .ctx   = */ new llama_sampler_greedy {
+            ("greedy"),
+        }
     );
 }

 // dist

-struct llama_sampler_dist {
+struct llama_sampler_dist : public llama_sampler_backend {
     const uint32_t seed;
           uint32_t seed_cur;

-    std::mt19937 rng;
-};
+    std::mt19937 rng;
+
+    // backend input
+    struct ggml_tensor * inp_uniform;
+
+    ggml_context_ptr        inp_ctx;
+    ggml_backend_buffer_ptr inp_buf;
+};
+
+static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    return sctx->get_name();
+}
+
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const double rnd = dist(ctx->rng);
+
+    double sum_run = 0.0f;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (don't think this can happen)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }

     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
+}
+
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }

 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
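The `#if 1` branch above fuses sampling and normalization: instead of normalizing first, it draws `rnd` once, scales it by the unnormalized total (`sum_tgt = sum_cum*rnd`), and walks the array a single time. A scalar demonstration with toy numbers (illustrative values only):

    #include <cstdio>

    int main() {
        // unnormalized probabilities (already exp(logit - max))
        const double p[4]    = { 2.0, 1.0, 0.5, 0.5 }; // sum_cum = 4.0
        const double rnd     = 0.80;                   // pretend dist(rng) returned this
        const double sum_tgt = 4.0 * rnd;              // 3.2 -- equivalent to comparing
                                                       // rnd against the normalized CDF
        double sum_run = 0.0;
        int selected = 3; // fallback: last token
        for (int i = 0; i < 4; ++i) {
            sum_run += p[i];
            if (sum_run >= sum_tgt) { selected = i; break; }
        }
        printf("selected = %d\n", selected); // 2: normalized CDF = 0.5, 0.75, 0.875, 1.0
    }

Deferring the division by `sum_cum` to the same loop saves a full pass over the candidate array, which matters at vocabulary sizes in the hundreds of thousands.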
@@ -595,75 +1129,158 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
     return result;
 }

-static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->seed_cur = get_rng_seed(ctx->seed);
-    ctx->rng.seed(ctx->seed_cur);
-}
-
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
     delete (llama_sampler_dist *) smpl->ctx;
 }

-static struct llama_sampler_i llama_sampler_dist_i = {
-    /* .name   = */ llama_sampler_dist_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_dist_apply,
-    /* .reset  = */ llama_sampler_dist_reset,
-    /* .clone  = */ llama_sampler_dist_clone,
-    /* .free   = */ llama_sampler_dist_free,
-};
+static bool llama_sampler_dist_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;

+    // allocate inputs
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        sctx->inp_ctx.reset(ggml_init(params));
+
+        // Create the uniform random scalar input tensor. This will be set by
+        // llama_sampler_dist_backend_set_input after this graph is built.
+        sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
+        ggml_set_name (sctx->inp_uniform, "uniform");
+        ggml_set_input(sctx->inp_uniform);
+
+        // Allocate all tensors from our context to the backend
+        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+
+        ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
+    }
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    if (!res) {
+        sctx->inp_ctx.reset(nullptr);
+        sctx->inp_buf.reset(nullptr);
+    }
+
+    return res;
 }

+static void llama_sampler_dist_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "dist_probs");
+
+    struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
+    ggml_set_name(cumsum, "dist_cumsum");
+
+    // The uniform tensor has a random value and we subtract this tensor with
+    // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
+    // Recall that each entry in cumsum is the cumulative probability up to that
+    // index so values stay negative while the cumulative total is below the
+    // random value, and become zero/positive once the threshold is crossed.
+    struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
+    ggml_set_name(diff, "dist_cumsum");
+
+    // The ggml_step function produces a tensor where entries are 1 if the
+    // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
+    // the index where the cumulative probability exceeds the random value are 0,
+    // and all entries after that are 1.
+    struct ggml_tensor * mask = ggml_step(ctx, diff);
+    ggml_set_name(mask, "dist_mask");
+
+    // Taking the sum of the mask gives us the sum of elements after the threshold
+    // we are interested in.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "dist_index_f32");
+
+    // Use ggml_scale_bias to scale the index value by -1 and then add the size
+    // of the mask to that value so we get the correct index ((-1 * idxf) + n).
+    struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
+    ggml_set_name(idx, "dist_index_i32");
+
+    // Map back to original vocab ids if a candidates tensor is available.
+    struct ggml_tensor * sampled_token = idx;
+    if (data->candidates != nullptr) {
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
+
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
+        ggml_set_name(sampled_token, "dist_sampled_token");
+    }
+
+    data->sampled = sampled_token;
+    data->probs   = probs;
+}

-static void
+static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
+
+    // We sample in double precision and cast to float to match rnd numbers of
+    // llama_dampler_dist which uses double precision (sampling from
+    // std::uniform_real_distribution<double> and
+    // std::uniform_real_distribution<float> with same rng will produce
+    // different sequences).
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const float rnd = dist(sctx->rng);
+
+    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
 }

-static struct llama_sampler_i
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name              = */ llama_sampler_dist_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_dist_apply,
+    /* .reset             = */ llama_sampler_dist_reset,
+    /* .clone             = */ llama_sampler_dist_clone,
+    /* .free              = */ llama_sampler_dist_free,
+    /* .backend_init      = */ llama_sampler_dist_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_dist_backend_apply,
+    /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
 };

-struct llama_sampler *
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
     return llama_sampler_init(
-        /* .iface = */ &
-        /* .ctx   = */
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx   = */ new llama_sampler_dist {
+            ("dist"),
+            /* .seed        = */ seed,
+            /* .seed_cur    = */ seed_cur,
+            /* .rng         = */ std::mt19937(seed_cur),
+            /* .inp_uniform = */ nullptr,
+            /* .inp_ctx     = */ nullptr,
+            /* .inp_buf     = */ nullptr,
+        }
     );
 }

 // top-k

-struct llama_sampler_top_k {
+struct llama_sampler_top_k : public llama_sampler_backend {
     const int32_t k;
 };

-static const char * llama_sampler_top_k_name(const struct llama_sampler *
+static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+    return sctx->get_name();
 }

 static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_top_k *) smpl->ctx;
     llama_sampler_top_k_impl(cur_p, ctx->k);
 }
@@ -676,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|
|
676
1293
|
delete (llama_sampler_top_k *) smpl->ctx;
|
|
677
1294
|
}
|
|
678
1295
|
|
|
1296
|
+
static bool llama_sampler_top_k_backend_init(
|
|
1297
|
+
struct llama_sampler * smpl,
|
|
1298
|
+
ggml_backend_buffer_type_t buft) {
|
|
1299
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1300
|
+
|
|
1301
|
+
const bool res = llama_sampler_backend_support(smpl, buft);
|
|
1302
|
+
|
|
1303
|
+
sctx->init(res);
|
|
1304
|
+
|
|
1305
|
+
return res;
|
|
1306
|
+
}
|
|
1307
|
+
|
|
1308
|
+
static void llama_sampler_top_k_backend_apply(
|
|
1309
|
+
struct llama_sampler * smpl,
|
|
1310
|
+
struct ggml_context * ctx,
|
|
1311
|
+
struct ggml_cgraph * gf,
|
|
1312
|
+
struct llama_sampler_data * data) {
|
|
1313
|
+
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
|
1314
|
+
|
|
1315
|
+
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
|
|
1316
|
+
ggml_set_name(top_k, "top_k");
|
|
1317
|
+
|
|
1318
|
+
if (data->candidates) {
|
|
1319
|
+
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
|
1320
|
+
data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
|
|
1321
|
+
data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
|
|
1322
|
+
ggml_set_name(data->candidates, "top_k_candidates");
|
|
1323
|
+
} else {
|
|
1324
|
+
data->candidates = top_k;
|
|
1325
|
+
}
|
|
1326
|
+
|
|
1327
|
+
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
|
1328
|
+
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
|
1329
|
+
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
|
|
1330
|
+
ggml_set_name(top_k_rows, "top_k_rows");
|
|
1331
|
+
|
|
1332
|
+
GGML_UNUSED(gf);
|
|
1333
|
+
}
|
|
1334
|
+
|
|
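A CPU analogue, in plain STL, of what the ggml graph built in llama_sampler_top_k_backend_apply computes: select the indices of the k largest logits, then gather both the logits and any existing candidate ids through those indices. This is an illustrative sketch, not code from the package:

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const std::vector<float> logits = {0.1f, 2.5f, -1.0f, 3.0f, 0.7f};
    const int k = 3;

    // Indices of the k largest logits, descending: the ggml_top_k analogue.
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });

    // Gathering through the index tensor: the reshape_2d + get_rows analogue.
    for (int i = 0; i < k; ++i) {
        std::printf("candidate id %d -> logit %.2f\n", idx[i], logits[idx[i]]);
    }
    return 0;
}
```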
 static struct llama_sampler_i llama_sampler_top_k_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_top_k_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_k_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_k_clone,
+    /* .free              = */ llama_sampler_top_k_free,
+    /* .backend_init      = */ llama_sampler_top_k_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_top_k_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    const bool is_empty = (k <= 0);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-k");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
+            ("top-k"),
             /* .k = */ k,
         }
     );
@@ -696,30 +1363,48 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {

 // top-p

-struct llama_sampler_top_p {
+struct llama_sampler_top_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
+
+    std::vector<llama_token_data> buf_sort;
 };

-static const char * llama_sampler_top_p_name(const struct llama_sampler *
-
+static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+    return sctx->get_name();
 }

 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-
+    auto * ctx = (llama_sampler_top_p *) smpl->ctx;

     if (ctx->p >= 1.0f) {
         return;
     }

-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, false);
+
+    size_t k = cur_p->size;
+    auto * pdata = cur_p->data;
+
+    auto & buf_sort = ctx->buf_sort;
+
+    // if not sorted, try adaptive top-k sorting
+    if (!cur_p->sorted && cur_p->size > 1024) {
+        k = std::min<size_t>(256, cur_p->size);
+        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+        pdata = buf_sort.data();
+    } else if (!cur_p->sorted) {
+        // small candidates -> sort inplace
+        llama_token_data_array_partial_sort_inplace(cur_p, k);
+    }

     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = cur_p->size;

     for (size_t i = 0; i < cur_p->size; ++i) {
-        cum_sum +=
+        cum_sum += pdata[i].p;

         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
@@ -727,9 +1412,21 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
             last_idx = i + 1;
             break;
         }
+
+        // we exceeded the current top-k heuristic -> increase k and continue
+        if (!cur_p->sorted && i == k - 1) {
+            k = cur_p->size;
+            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+            pdata = buf_sort.data();
+        }
     }

     // Resize the output vector to keep only the top-p tokens
+    if (!cur_p->sorted) {
+        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
+        cur_p->sorted = true;
+    }
+
     cur_p->size = last_idx;
 }

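The CPU top-p path above avoids fully sorting large candidate arrays: it partially sorts a 256-entry prefix first and only widens to a full sort if the cumulative probability has not reached p within that prefix. A self-contained sketch of the same heuristic on plain floats (the 256 cutoff is shrunk to 2 here so the widening branch actually fires; illustrative only):

```cpp
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    // Already-normalized, unsorted probabilities.
    std::vector<float> sorted = {0.01f, 0.4f, 0.02f, 0.3f, 0.2f, 0.07f};
    const float p = 0.8f;

    // Small stand-in for the 256-entry heuristic used on >1024 candidates.
    size_t k = std::min<size_t>(2, sorted.size());
    std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<float>());

    float  cum  = 0.0f;
    size_t last = sorted.size();
    for (size_t i = 0; i < sorted.size(); ++i) {
        cum += sorted[i];
        if (cum >= p) {
            last = i + 1;
            break;
        }
        if (i == k - 1 && k < sorted.size()) {
            // Heuristic exhausted before reaching p: widen to a full sort.
            k = sorted.size();
            std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<float>());
        }
    }
    std::printf("kept %zu of %zu tokens\n", last, sorted.size());
    return 0;
}
```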
@@ -742,38 +1439,139 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_top_p *) smpl->ctx;
 }

+static bool llama_sampler_top_p_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_top_p_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+
+    auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
+        GGML_ASSERT(ggml_nrows(a) == 1);
+        struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
+        struct ggml_tensor * a_sorted   = ggml_get_rows(ctx, a_reshaped, b);
+        return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
+    };
+
+    // Get the sorted logits in descending order.
+    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
+    ggml_set_name(sorted_idx, "top_p_sorted_idx");
+
+    // Do the sorting via reshape + get_rows
+    struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
+    ggml_set_name(sorted_logits, "top_p_sorted_logits");
+
+    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
+    ggml_set_name(softmax, "top_p_softmax");
+
+    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
+    if (data->candidates) {
+        data->candidates = ggml_sort(data->candidates, sorted_idx);
+    } else {
+        data->candidates = sorted_idx;
+    }
+    ggml_set_name(data->candidates, "top_p_candidates");
+
+    // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
+    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
+    ggml_set_name(cdf, "top_p_cdf");
+
+    // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
+    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
+    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
+
+    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
+    ggml_set_name(mask, "top_p_mask");
+
+    // Taking the sum of the mask gives us the number of elements before the
+    // threshold we are interested in.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "top_p_index_f32");
+
+    // prevent out-of-bounds access
+    idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
+
+    // construct ones tensor to set the value in the mask
+    struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
+    ggml_set_name(ones, "top_p_ones");
+
+    // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
+    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
+
+    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
+    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
+
+    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
+    // top_p_bias = (mask * 1e9f) - 1e9f.
+    // So entries in the mask that we want to discard will become -1e9f, and
+    // others will be 0 (meaning that they will not affect the logits).
+    const float large_val = 1e9f;
+    struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    ggml_set_name(top_p_bias, "top_p_bias");
+
+    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
+    ggml_set_name(data->logits, "top_p_logits");
+
+    GGML_UNUSED(gf);
+}
+
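A scalar walk-through of the masking arithmetic in llama_sampler_top_p_backend_apply, assuming ggml_step(x) yields 1 for x > 0: the CDF of the sorted softmax is negated and shifted by p, the step turns that into a keep mask, one extra slot is forced to 1 so the cut is inclusive, and the mask becomes a 0 / -1e9 additive bias. Illustrative sketch, not package code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Logits assumed already sorted in descending order.
    const std::vector<float> logits = {3.0f, 2.0f, 1.0f, 0.0f};
    const float p = 0.8f;

    float denom = 0.0f;
    for (float l : logits) denom += std::exp(l);

    const size_t n = logits.size();
    std::vector<float> mask(n, 0.0f);

    float cdf = 0.0f;
    int n_keep = 0;
    for (size_t i = 0; i < n; ++i) {
        cdf += std::exp(logits[i]) / denom;           // ggml_cumsum(ggml_soft_max(...))
        mask[i] = (-cdf + p) > 0.0f ? 1.0f : 0.0f;    // ggml_step(ggml_scale_bias(cdf, -1, p))
        n_keep += (int) mask[i];
    }
    // The ggml_clamp + ggml_set_rows(ones) step: make the cut inclusive.
    const size_t inc = std::min<size_t>((size_t) n_keep, n - 1);
    mask[inc] = 1.0f;

    const float large_val = 1e9f;
    for (size_t i = 0; i < n; ++i) {
        std::printf("logit %.1f -> %.1f\n", logits[i], logits[i] + (mask[i] * large_val - large_val));
    }
    return 0;
}
```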
 static struct llama_sampler_i llama_sampler_top_p_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_top_p_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_p_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_p_clone,
+    /* .free              = */ llama_sampler_top_p_free,
+    /* .backend_init      = */ llama_sampler_top_p_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_top_p_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    const bool is_empty = p >= 1.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-p");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
+            ("top-p"),
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
+            /* .buf_sort = */ {},
         }
     );
 }

 // min-p

-struct llama_sampler_min_p {
+struct llama_sampler_min_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
 };

-static const char * llama_sampler_min_p_name(const struct llama_sampler *
-
+static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+    return sctx->get_name();
 }

 static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-
+    auto * ctx = (llama_sampler_min_p *) smpl->ctx;

     if (ctx->p <= 0.0f || !cur_p->size) {
         return;
@@ -799,7 +1597,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d

     // if we have enough values the operation was a success
     if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
-
+        std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
         cur_p->size = filtered_tokens.size();
         min_p_applied = true;
     }
@@ -809,10 +1607,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
     if (!min_p_applied) {
         // Sort the logits in descending order
         if (!cur_p->sorted) {
-
-                return a.logit > b.logit;
-            });
-            cur_p->sorted = true;
+            llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
         }

         const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
@@ -838,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_min_p *) smpl->ctx;
 }

+static bool llama_sampler_min_p_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_min_p_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+
+    struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
+    ggml_set_name(max_idx, "max_idx");
+
+    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+    ggml_set_name(logits_rows, "logits_rows");
+
+    struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
+    ggml_set_name(max_logit, "max_logit");
+
+    // Calculate the threshold value.
+    struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
+    ggml_set_name(threshold, "min_p_threshold");
+
+    // Subtract the threshold from logits.
+    struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
+
+    // Create a mask where logits below the threshold are 0 (discard),
+    // and others are 1 (keep).
+    struct ggml_tensor * mask = ggml_step(ctx, sub);
+    ggml_set_name(mask, "min_p_mask");
+
+    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
+    // min_p_bias = (mask * 1e9f) - 1e9f.
+    // So entries in the mask that we want to discard will become -1e9f, and
+    // others will be 0 (meaning that they will not affect the logits).
+    const float large_val = 1e9f;
+    struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    ggml_set_name(min_p_bias, "min_p_bias");
+
+    // Add the min_p bias to the logits.
+    data->logits = ggml_add(ctx, data->logits, min_p_bias);
+    ggml_set_name(data->logits, "min_p_logits");
+
+    GGML_UNUSED(gf);
+}
+
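The min-p graph above works entirely in log space: keeping tokens with p_i >= p * p_max is the same as keeping logits >= max_logit + log(p), since p_i / p_max = exp(logit_i - max_logit). A scalar sketch of that identity (assuming the strict > of ggml_step; illustrative only):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> logits = {3.0f, 2.0f, -1.0f};
    const float p = 0.2f;

    const float max_logit = *std::max_element(logits.begin(), logits.end());
    const float threshold = max_logit + std::log(p);  // ggml_scale_bias(max_logit, 1, logf(p))

    for (float l : logits) {
        // ggml_step(ggml_sub(logits, threshold)): 1 = keep, 0 = discard.
        const bool keep = (l - threshold) > 0.0f;
        std::printf("logit %.1f -> %s\n", l, keep ? "keep" : "-1e9");
    }
    return 0;
}
```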
 static struct llama_sampler_i llama_sampler_min_p_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_min_p_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_min_p_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_min_p_clone,
+    /* .free              = */ llama_sampler_min_p_free,
+    /* .backend_init      = */ llama_sampler_min_p_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_min_p_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+    const bool is_empty = (p <= 0.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?min-p");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
+            ("min-p"),
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
         }
@@ -869,7 +1730,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*sm
 }

 static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-
+    auto * ctx = (llama_sampler_typical *) smpl->ctx;

     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
@@ -878,7 +1739,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
     }

     // Compute the softmax of logits and calculate entropy
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);

     float entropy = 0.0f;
     for (size_t i = 0; i < cur_p->size; ++i) {
@@ -938,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_typical_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_typical_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_typical_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_typical_clone,
+    /* .free              = */ llama_sampler_typical_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    const bool is_empty = (p >= 1.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?typical");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
@@ -958,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {

 // temp

-struct llama_sampler_temp {
+struct llama_sampler_temp : public llama_sampler_backend {
     const float temp;
 };

-static const char * llama_sampler_temp_name(const struct llama_sampler *
-
+static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    return sctx->get_name();
 }

 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -981,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }

+static void llama_sampler_backend_temp_sampling(
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data,
+        float temp) {
+    if (temp <= 0.0f) {
+        // Find the most probable token index.
+        struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
+        ggml_set_name(max_idx, "temp_max_idx");
+
+        if (data->candidates) {
+            struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+            data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
+        } else {
+            data->candidates = max_idx;
+        }
+
+        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
+
+        return;
+    }
+
+    data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
+
+    GGML_UNUSED(gf);
+}
+
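A scalar sketch of the two branches in llama_sampler_backend_temp_sampling: a non-positive temperature collapses to greedy argmax selection, while a positive one simply multiplies the logits by 1/temp. Illustrative only, not package code:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

static void temp_sampling(std::vector<float> & logits, float temp) {
    if (temp <= 0.0f) {
        // Greedy path: keep only the most probable token (ggml_argmax + get_rows).
        const size_t best = std::max_element(logits.begin(), logits.end()) - logits.begin();
        logits = { logits[best] };
        return;
    }
    for (float & l : logits) {
        l *= 1.0f / temp;  // ggml_scale(logits, 1/temp)
    }
}

int main() {
    std::vector<float> a = {1.0f, 4.0f, 2.0f};
    temp_sampling(a, 0.0f);
    std::printf("greedy keeps %zu logit(s): %.1f\n", a.size(), a[0]);

    std::vector<float> b = {1.0f, 4.0f, 2.0f};
    temp_sampling(b, 0.5f);
    std::printf("temp 0.5 scales to: %.1f %.1f %.1f\n", b[0], b[1], b[2]);
    return 0;
}
```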
+static bool llama_sampler_temp_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_temp_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
+}
+
 static struct llama_sampler_i llama_sampler_temp_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_temp_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_temp_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_temp_clone,
+    /* .free              = */ llama_sampler_temp_free,
+    /* .backend_init      = */ llama_sampler_temp_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_temp(float temp) {
+    const bool is_empty = temp == 1.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?temp");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
+            ("temp"),
             /*.temp = */ temp,
         }
     );
@@ -1001,18 +1933,19 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {

 // temp-ext

-struct llama_sampler_temp_ext {
+struct llama_sampler_temp_ext : public llama_sampler_backend {
     const float temp;
     const float delta;
     const float exponent;
 };

-static const char * llama_sampler_temp_ext_name(const struct llama_sampler *
-
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+    return sctx->get_name();
 }

 static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-
+    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
@@ -1027,7 +1960,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
         // Calculate maximum possible entropy
         float max_entropy = -logf(1.0f / cur_p->size);

-        llama_sampler_softmax_impl(cur_p);
+        llama_sampler_softmax_impl(cur_p, true);

         // Calculate entropy of the softmax probabilities
         float entropy = 0.0f;
@@ -1091,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }

+static bool llama_sampler_temp_ext_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_temp_ext_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+
+    // Revert to standard temperature scaling if delta or temp are non-positive.
+    if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
+        llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
+        return;
+    }
+
+    // Calculate min_temp, max_temp, and max_entropy.
+    const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
+    const float max_temp = sctx->temp + sctx->delta;
+    const float max_entropy = logf(data->logits->ne[0]);
+
+    // Calculate the probabilities.
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "temp_ext_softmax_probs");
+
+    // Clamp probabilities to avoid log(0) which would give -inf
+    struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
+    ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
+
+    // Calculate the entropy, entropy = -Σ(p * log(p)).
+    struct ggml_tensor * log_probs   = ggml_log(ctx, probs_clamped);
+    struct ggml_tensor * p_log_p     = ggml_mul(ctx, probs_clamped, log_probs);
+    struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
+    struct ggml_tensor * entropy     = ggml_scale(ctx, sum_p_log_p, -1.0f);
+    ggml_set_name(log_probs,   "temp_ext_log_probs");
+    ggml_set_name(p_log_p,     "temp_ext_p_log_p");
+    ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
+    ggml_set_name(entropy,     "temp_ext_entropy");
+
+    // Normalize the entropy, norm_entropy = entropy / max_entropy
+    struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
+    ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
+
+    // Calculate the dynamic temperature:
+    // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
+    //
+    // Calculate powf(normalized_entropy, exponent) as
+    // norm_entropy^exponent = exp(exponent * log(norm_entropy))
+    struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
+    struct ggml_tensor * scaled_log       = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
+    struct ggml_tensor * pow_entropy      = ggml_exp(ctx, scaled_log);
+    // With pow_entropy computed we can now compute dyn_temp, scaling by
+    // (max_temp - min_temp) and then adding min_temp.
+    struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
+    ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
+    ggml_set_name(scaled_log,       "temp_ext_scaled_log");
+    ggml_set_name(pow_entropy,      "temp_ext_pow_entropy");
+    ggml_set_name(dyn_temp,         "temp_ext_dyn_temp");

+    // Scale the logits by the dynamic temperature
+    struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
+    ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
+
+    data->logits = scaled_logits;
+}
+
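A scalar version of the entropy-driven formula the graph above assembles, dyn_temp = min_temp + (max_temp - min_temp) * powf(norm_entropy, exponent), with the power computed as exp(exponent * log(norm_entropy)) just like the ggml_log/ggml_scale/ggml_exp chain. Illustrative sketch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> logits = {2.0f, 1.0f, 0.0f};
    const float temp = 0.8f, delta = 0.4f, exponent = 1.0f;

    const float min_temp    = std::max(0.0f, temp - delta);
    const float max_temp    = temp + delta;
    const float max_entropy = std::log((float) logits.size());

    // Softmax, then entropy = -sum p*log(p), with p clamped away from 0.
    float denom = 0.0f;
    for (float l : logits) denom += std::exp(l);
    float entropy = 0.0f;
    for (float l : logits) {
        const float p = std::max(std::exp(l) / denom, 1e-10f);
        entropy -= p * std::log(p);
    }

    const float norm_entropy = entropy / max_entropy;
    const float dyn_temp = min_temp
        + (max_temp - min_temp) * std::exp(exponent * std::log(norm_entropy));

    std::printf("entropy %.3f / %.3f -> dyn_temp %.3f\n", entropy, max_entropy, dyn_temp);
    return 0;
}
```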
 static struct llama_sampler_i llama_sampler_temp_ext_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_temp_ext_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_temp_ext_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_temp_ext_clone,
+    /* .free              = */ llama_sampler_temp_ext_free,
+    /* .backend_init      = */ llama_sampler_temp_ext_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-
+    const bool is_empty = temp == 1.0f && delta <= 0.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?temp-ext");
+    }
+
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
+            ("temp-ext"),
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
         }
     );
+
+    return res;
 }

 // xtc
@@ -1139,17 +2160,20 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data

     std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
     float chance = distribution(ctx->rng);
-    if (chance > ctx->probability)
+    if (chance > ctx->probability) {
+        return;
+    }

-
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);

     int pos_last = 0;

     for (size_t i = 0; i < cur_p->size; ++i) {
         if (cur_p->data[i].p >= ctx->threshold) {
             pos_last = i;
-        } else
+        } else {
+            break;
+        }
     }

     if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
@@ -1183,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_xtc_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_xtc_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sample_xtc_apply,
+    /* .reset             = */ llama_sampler_xtc_reset,
+    /* .clone             = */ llama_sampler_xtc_clone,
+    /* .free              = */ llama_sampler_xtc_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
-
+    const bool is_empty = (p <= 0.0f || t > 0.5f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?xtc");
+    }
+
+    const auto seed_cur = get_rng_seed(seed);
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
@@ -1221,7 +2256,7 @@ struct llama_sampler_mirostat {

     float mu;

-    std::mt19937
+    std::mt19937 rng;
 };

 static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
@@ -1231,7 +2266,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s
 static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;

-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);

     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;
@@ -1250,7 +2285,8 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
     float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);

     llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
-
+
+    llama_sampler_softmax_impl(cur_p, true);

     const int idx = llama_sample_dist(cur_p, ctx->rng);

@@ -1290,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_mirostat_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_mirostat_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_mirostat_apply,
+    /* .reset             = */ llama_sampler_mirostat_reset,
+    /* .clone             = */ llama_sampler_mirostat_clone,
+    /* .free              = */ llama_sampler_mirostat_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
-    auto seed_cur = get_rng_seed(seed);
+    const auto seed_cur = get_rng_seed(seed);
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
@@ -1336,7 +2377,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler *
 static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;

-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);

     // Truncate the words with surprise values greater than mu
     cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
@@ -1348,7 +2389,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
     }

     // Normalize the probabilities of the remaining words
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);

     const int idx = llama_sample_dist(cur_p, ctx->rng);

@@ -1389,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_mirostat_v2_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_mirostat_v2_apply,
+    /* .reset             = */ llama_sampler_mirostat_v2_reset,
+    /* .clone             = */ llama_sampler_mirostat_v2_clone,
+    /* .free              = */ llama_sampler_mirostat_v2_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1506,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_grammar_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_grammar_name,
+    /* .accept            = */ llama_sampler_grammar_accept_impl,
+    /* .apply             = */ llama_sampler_grammar_apply,
+    /* .reset             = */ llama_sampler_grammar_reset,
+    /* .clone             = */ llama_sampler_grammar_clone,
+    /* .free              = */ llama_sampler_grammar_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1528,10 +2577,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
     auto * ctx = new llama_sampler_grammar;

     if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        std::string trigger_pattern;
+        llama_grammar * grammar = nullptr;
         // TODO: remove trigger_words support.
         if (trigger_words != nullptr && num_trigger_words > 0) {
             GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
-
+            trigger_pattern = "[\\s\\S]*?(";
             for (size_t i = 0; i < num_trigger_words; ++i) {
                 static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
                 if (i > 0) {
@@ -1540,15 +2591,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
                 trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
             }
             trigger_pattern += ")[\\s\\S]*";
-
-
-
+
+            std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
+            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
+        } else {
+            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
         }
         *ctx = {
             /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */
+            /* .grammar      = */ grammar,
         };
         if (!ctx->grammar) {
             delete ctx;
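The trigger-word fallback above builds one regex from the legacy word list: each word is regex-escaped with std::regex_replace and the "\\$0" format (which prefixes every matched special character with a backslash), and the escaped words are joined with | inside a [\s\S]*?(...)[\s\S]* wrapper. A standalone sketch of just that string construction:

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> trigger_words = {"<tool_call>", "a+b"};

    static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");

    std::string trigger_pattern = "[\\s\\S]*?(";
    for (size_t i = 0; i < trigger_words.size(); ++i) {
        if (i > 0) {
            trigger_pattern += "|";
        }
        // "\\$0" re-inserts the matched special character behind a backslash.
        trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
    }
    trigger_pattern += ")[\\s\\S]*";

    // Prints: [\s\S]*?(<tool_call>|a\+b)[\s\S]*
    std::printf("%s\n", trigger_pattern.c_str());
    return 0;
}
```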
@@ -1709,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_penalties_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_penalties_name,
+    /* .accept            = */ llama_sampler_penalties_accept,
+    /* .apply             = */ llama_sampler_penalties_apply,
+    /* .reset             = */ llama_sampler_penalties_reset,
+    /* .clone             = */ llama_sampler_penalties_clone,
+    /* .free              = */ llama_sampler_penalties_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_penalties(
@@ -1724,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);

+    const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?penalties");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
@@ -1748,7 +2811,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 }

 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-
+    auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;

     if (ctx->n <= 0.0f || cur_p->size <= 1) {
         return;
@@ -1761,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     for (size_t i = 0; i < cur_p->size; ++i) {
         // Only count non-negative infinity values
         if (cur_p->data[i].logit != -INFINITY) {
-
-            max = cur_p->data[i].logit;
-        }
+            max = std::max(max, cur_p->data[i].logit);
             logits_sum += cur_p->data[i].logit;
             valid_count++;
         }
@@ -1780,13 +2841,14 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     }
     float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;

-    //apply mask
+    // apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {
         if (cur_p->data[i].logit < max - (ctx->n * std)) {
             cur_p->data[i].logit = -INFINITY;
         }
     }
-
+
+    llama_sampler_softmax_impl(cur_p, true);
 }

 static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
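A scalar sketch of the top-n-sigma filter as updated above: take the maximum and standard deviation over the finite logits, then push everything below max - n * std to -INFINITY (the real sampler then re-applies softmax). Illustrative only:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = {3.0f, 2.5f, 0.0f, -4.0f};
    const float n = 1.0f;

    float max = -INFINITY, sum = 0.0f;
    int count = 0;
    for (float l : logits) {
        if (l != -INFINITY) { max = std::max(max, l); sum += l; ++count; }
    }
    const float mean = count > 0 ? sum / count : 0.0f;

    float acc = 0.0f;
    for (float l : logits) {
        if (l != -INFINITY) { acc += (l - mean) * (l - mean); }
    }
    const float std_dev = count > 0 ? std::sqrt(acc / count) : 0.0f;

    // Mask out everything more than n standard deviations below the max.
    for (float & l : logits) {
        if (l < max - n * std_dev) { l = -INFINITY; }
    }
    for (float l : logits) std::printf("%f\n", l);
    return 0;
}
```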
@@ -1799,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_top_n_sigma_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_n_sigma_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_n_sigma_clone,
+    /* .free              = */ llama_sampler_top_n_sigma_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

 struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    const bool is_empty = (n <= 0.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-n-sigma");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_n_sigma_i,
         /* .ctx   = */ new llama_sampler_top_n_sigma {
@@ -1991,7 +3063,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat

     {
         const int last = last_n_repeat - 1;
-
+
+        int rt = 0;
+        int lt = 0;

         for (int k = 1; k < last_n_repeat; ++k) {
             if (k > rt) {
@@ -2127,22 +3201,30 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
 }

 static struct llama_sampler_i llama_sampler_dry_i = {
-    /* .name
-    /* .accept
-    /* .apply
-    /* .reset
-    /* .clone
-    /* .free
+    /* .name              = */ llama_sampler_dry_name,
+    /* .accept            = */ llama_sampler_dry_accept,
+    /* .apply             = */ llama_sampler_dry_apply,
+    /* .reset             = */ llama_sampler_dry_reset,
+    /* .clone             = */ llama_sampler_dry_clone,
+    /* .free              = */ llama_sampler_dry_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };

-struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t
-    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ?
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
     const int MAX_SEQ_LEN = 20;

     const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

+    if (!dry_enabled) {
+        return llama_sampler_init_empty("?dry");
+    }
+
     if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
         // Process sequence breakers
         for (size_t i = 0; i < num_breakers; ++i) {
@@ -2169,7 +3251,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
-            /* .total_context_size = */
+            /* .total_context_size = */ n_ctx_train,
             /* .dry_multiplier     = */ dry_multiplier,
             /* .dry_base           = */ dry_base,
             /* .dry_allowed_length = */ dry_allowed_length,
@@ -2213,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa

 // logit-bias

-struct llama_sampler_logit_bias {
+struct llama_sampler_logit_bias : public llama_sampler_backend {
     const int32_t n_vocab;

     const std::vector<llama_logit_bias> logit_bias;

     std::vector<llama_logit_bias> to_search;
+
+    struct ggml_tensor * inp_logit_bias;
+    struct ggml_tensor * inp_logit_idxs;
+
+    ggml_context_ptr        inp_ctx;
+    ggml_backend_buffer_ptr inp_buf;
 };

-static const char * llama_sampler_logit_bias_name(const struct llama_sampler *
-
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+    return ctx->get_name();
 }

 static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2267,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
     delete (llama_sampler_logit_bias *) smpl->ctx;
 }

+static void llama_sampler_logit_bias_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(ctx);
+
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
+
+    cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
+    cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
+    cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
+
+    data->logits = ggml_add(ctx, data->logits, cur);
+}
+
+static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    GGML_ASSERT(sctx->inp_logit_bias != nullptr);
+    GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
+
+    const size_t n = sctx->logit_bias.size();
+
+    std::vector<float>   data_logit_bias(n, 0.0f);
+    std::vector<int32_t> data_logit_idxs(n, 0);
+    for (size_t i = 0; i < n; ++i) {
+        const auto & lb = sctx->logit_bias[i];
+        GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
+        data_logit_bias[i] = lb.bias;
+        data_logit_idxs[i] = lb.token;
+    }
+
+    ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
+    ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
+}
+
+static bool llama_sampler_logit_bias_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+    sctx->init(true);
+
+    if (sctx->logit_bias.empty()) {
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    sctx->inp_ctx.reset(ggml_init(params));
+
+    const size_t n = sctx->logit_bias.size();
+
+    sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
+    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
+    ggml_set_input(sctx->inp_logit_bias);
+
+    sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
+    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
+    ggml_set_input(sctx->inp_logit_idxs);
+
+    // Allocate all tensors from our context to the backend
+    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+
+    ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
+
+    return true;
+}
+
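A CPU analogue of the logit-bias backend path above: the (token, bias) pairs are uploaded once into two small input tensors, and each graph run scatters the biases into a zeroed row (ggml_fill + ggml_set_rows) before a single ggml_add onto the logits. Sketch with plain vectors, not package code:

```cpp
#include <cstdio>
#include <vector>

struct logit_bias { int token; float bias; };

int main() {
    const int n_vocab = 6;
    const std::vector<logit_bias> biases = {{1, +2.0f}, {4, -1.5f}};

    std::vector<float> logits(n_vocab, 0.5f);

    // Zero scratch row, scattered writes, then a single add: mirrors
    // ggml_fill + ggml_set_rows + ggml_add in the backend graph.
    std::vector<float> cur(n_vocab, 0.0f);
    for (const auto & lb : biases) {
        cur[lb.token] = lb.bias;
    }
    for (int i = 0; i < n_vocab; ++i) {
        logits[i] += cur[i];
        std::printf("token %d: %.2f\n", i, logits[i]);
    }
    return 0;
}
```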
|
2271
|
-
/* .name
|
|
2272
|
-
/* .accept
|
|
2273
|
-
/* .apply
|
|
2274
|
-
/* .reset
|
|
2275
|
-
/* .clone
|
|
2276
|
-
/* .free
|
|
3443
|
+
/* .name = */ llama_sampler_logit_bias_name,
|
|
3444
|
+
/* .accept = */ nullptr,
|
|
3445
|
+
/* .apply = */ llama_sampler_logit_bias_apply,
|
|
3446
|
+
/* .reset = */ nullptr,
|
|
3447
|
+
/* .clone = */ llama_sampler_logit_bias_clone,
|
|
3448
|
+
/* .free = */ llama_sampler_logit_bias_free,
|
|
3449
|
+
/* .backend_init = */ llama_sampler_logit_bias_backend_init,
|
|
3450
|
+
/* .backend_accept = */ nullptr,
|
|
3451
|
+
/* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
|
|
3452
|
+
/* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
|
|
2277
3453
|
};
|
|
2278
3454
|
|
|
2279
3455
|
struct llama_sampler * llama_sampler_init_logit_bias(
|
|
2280
3456
|
int32_t n_vocab,
|
|
2281
3457
|
int32_t n_logit_bias,
|
|
2282
3458
|
const llama_logit_bias * logit_bias) {
|
|
3459
|
+
const bool is_empty = n_logit_bias <= 0;
|
|
3460
|
+
|
|
3461
|
+
if (is_empty) {
|
|
3462
|
+
return llama_sampler_init_empty("?logit-bias");
|
|
3463
|
+
}
|
|
3464
|
+
|
|
2283
3465
|
return llama_sampler_init(
|
|
2284
3466
|
/* .iface = */ &llama_sampler_logit_bias_i,
|
|
2285
3467
|
/* .ctx = */ new llama_sampler_logit_bias {
|
|
2286
|
-
|
|
2287
|
-
/* .
|
|
2288
|
-
/* .
|
|
3468
|
+
("logit-bias"),
|
|
3469
|
+
/* .n_vocab = */ n_vocab,
|
|
3470
|
+
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
|
3471
|
+
/* .to_search = */ {},
|
|
3472
|
+
/* .inp_logit_bias = */ nullptr,
|
|
3473
|
+
/* .inp_logit_idxs = */ nullptr,
|
|
3474
|
+
/* .inp_ctx = */ nullptr,
|
|
3475
|
+
/* .inp_buf = */ nullptr,
|
|
2289
3476
|
}
|
|
2290
3477
|
);
|
|
2291
3478
|
}
|
|
@@ -2308,7 +3495,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp
|
|
|
2308
3495
|
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
2309
3496
|
auto * ctx = (llama_sampler_infill *) smpl->ctx;
|
|
2310
3497
|
|
|
2311
|
-
llama_sampler_softmax_impl(cur_p);
|
|
3498
|
+
llama_sampler_softmax_impl(cur_p, true);
|
|
2312
3499
|
|
|
2313
3500
|
#if defined(GGML_DEBUG_SAMPLER_INFILL)
|
|
2314
3501
|
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
|
|
@@ -2441,8 +3628,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
|
|
2441
3628
|
if (n_non_eog == 0) {
|
|
2442
3629
|
cur_p->size = 1;
|
|
2443
3630
|
cur_p->data[0].id = ctx->vocab->token_eot();
|
|
3631
|
+
if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
|
|
3632
|
+
cur_p->data[0].id = ctx->vocab->token_eos();
|
|
3633
|
+
}
|
|
2444
3634
|
cur_p->data[0].logit = 1.0f;
|
|
2445
3635
|
|
|
3636
|
+
GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
|
|
3637
|
+
|
|
2446
3638
|
return;
|
|
2447
3639
|
}
|
|
2448
3640
|
|
|
@@ -2493,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
|
|
2493
3685
|
}
|
|
2494
3686
|
|
|
2495
3687
|
static struct llama_sampler_i llama_sampler_infill_i = {
|
|
2496
|
-
/* .name
|
|
2497
|
-
/* .accept
|
|
2498
|
-
/* .apply
|
|
2499
|
-
/* .reset
|
|
2500
|
-
/* .clone
|
|
2501
|
-
/* .free
|
|
3688
|
+
/* .name = */ llama_sampler_infill_name,
|
|
3689
|
+
/* .accept = */ nullptr,
|
|
3690
|
+
/* .apply = */ llama_sampler_infill_apply,
|
|
3691
|
+
/* .reset = */ nullptr,
|
|
3692
|
+
/* .clone = */ llama_sampler_infill_clone,
|
|
3693
|
+
/* .free = */ llama_sampler_infill_free,
|
|
3694
|
+
/* .backend_apply = */ nullptr,
|
|
3695
|
+
/* .backend_accept = */ nullptr,
|
|
3696
|
+
/* .backend_set_input = */ nullptr,
|
|
3697
|
+
/* .backend_init = */ nullptr,
|
|
2502
3698
|
};
|
|
2503
3699
|
|
|
2504
3700
|
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
|
@@ -2530,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|
|
2530
3726
|
if (smpl->iface == &llama_sampler_chain_i) {
|
|
2531
3727
|
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
|
2532
3728
|
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
|
2533
|
-
const uint32_t seed = llama_sampler_get_seed(
|
|
3729
|
+
const uint32_t seed = llama_sampler_get_seed(it->ptr);
|
|
2534
3730
|
if (seed != LLAMA_DEFAULT_SEED) {
|
|
2535
3731
|
return seed;
|
|
2536
3732
|
}
|
|
@@ -2560,8 +3756,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
|
|
|
2560
3756
|
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
|
2561
3757
|
const auto data = llama_perf_sampler(chain);
|
|
2562
3758
|
|
|
2563
|
-
LLAMA_LOG_INFO("%s:
|
|
2564
|
-
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
|
3759
|
+
LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
|
|
2565
3760
|
}
|
|
2566
3761
|
|
|
2567
3762
|
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
@@ -2571,5 +3766,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
|
|
2571
3766
|
|
|
2572
3767
|
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
|
2573
3768
|
|
|
2574
|
-
ctx->t_sample_us =
|
|
3769
|
+
ctx->t_sample_us = 0;
|
|
3770
|
+
ctx->n_sample = 0;
|
|
2575
3771
|
}
|