whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
#include "llama-context.h"
|
|
2
2
|
|
|
3
|
+
#include "llama-arch.h"
|
|
3
4
|
#include "llama-impl.h"
|
|
4
5
|
#include "llama-batch.h"
|
|
5
6
|
#include "llama-io.h"
|
|
@@ -8,6 +9,7 @@
|
|
|
8
9
|
#include "llama-model.h"
|
|
9
10
|
|
|
10
11
|
#include <cinttypes>
|
|
12
|
+
#include <cmath>
|
|
11
13
|
#include <cstring>
|
|
12
14
|
#include <limits>
|
|
13
15
|
#include <stdexcept>
|
|
@@ -21,6 +23,8 @@ llama_context::llama_context(
|
|
|
21
23
|
llama_context_params params) :
|
|
22
24
|
model(model),
|
|
23
25
|
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
|
26
|
+
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
|
27
|
+
// may need to be backend-dependent
|
|
24
28
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
|
25
29
|
|
|
26
30
|
t_start_us = model.t_start_us;
|
|
@@ -56,6 +60,25 @@ llama_context::llama_context(
|
|
|
56
60
|
cparams.cb_eval = params.cb_eval;
|
|
57
61
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
|
58
62
|
|
|
63
|
+
// Initialize backend samplers here so they are part of the sampling graph
|
|
64
|
+
// before the reserve passes run later in this function. This avoids a later
|
|
65
|
+
// re-reserve when graph nodes change.
|
|
66
|
+
if (params.samplers != nullptr && params.n_samplers > 0) {
|
|
67
|
+
for (size_t i = 0; i < params.n_samplers; ++i) {
|
|
68
|
+
const auto & config = params.samplers[i];
|
|
69
|
+
|
|
70
|
+
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
|
|
71
|
+
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
if (set_sampler(config.seq_id, config.sampler)) {
|
|
75
|
+
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
|
76
|
+
|
|
77
|
+
LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
59
82
|
auto rope_scaling_type = params.rope_scaling_type;
|
|
60
83
|
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
|
61
84
|
rope_scaling_type = hparams.rope_scaling_type_train;
|
|
@@ -69,6 +92,43 @@ llama_context::llama_context(
|
|
|
69
92
|
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
|
70
93
|
}
|
|
71
94
|
|
|
95
|
+
if (cparams.yarn_ext_factor != 0) {
|
|
96
|
+
static auto get_mscale = [](float scale, float mscale) {
|
|
97
|
+
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
|
98
|
+
};
|
|
99
|
+
|
|
100
|
+
const float factor = 1.0f / cparams.rope_freq_scale;
|
|
101
|
+
|
|
102
|
+
// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
|
|
103
|
+
if (hparams.rope_yarn_log_mul != 0.0f) {
|
|
104
|
+
// note: here we assume `mscale == 1.0f`
|
|
105
|
+
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
|
|
106
|
+
float mscale = 1.0f;
|
|
107
|
+
const float mscale_all_dims = hparams.rope_yarn_log_mul;
|
|
108
|
+
|
|
109
|
+
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
|
110
|
+
// special-case DEEPSEEK v2:
|
|
111
|
+
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
|
|
112
|
+
if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
|
|
113
|
+
mscale = mscale_all_dims;
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
|
117
|
+
|
|
118
|
+
LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
|
|
119
|
+
__func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
|
|
120
|
+
} else {
|
|
121
|
+
cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
|
|
125
|
+
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
|
|
126
|
+
//
|
|
127
|
+
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
|
|
128
|
+
// https://github.com/ggml-org/llama.cpp/pull/17945
|
|
129
|
+
cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
|
|
130
|
+
}
|
|
131
|
+
|
|
72
132
|
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
|
73
133
|
|
|
74
134
|
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
@@ -90,14 +150,6 @@ llama_context::llama_context(
|
|
|
90
150
|
// with causal attention, the batch size is limited by the context size
|
|
91
151
|
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
|
92
152
|
|
|
93
|
-
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
|
94
|
-
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
|
95
|
-
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
|
96
|
-
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
|
97
|
-
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
|
98
|
-
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
|
99
|
-
cparams.n_batch = GGML_KQ_MASK_PAD;
|
|
100
|
-
}
|
|
101
153
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
|
102
154
|
|
|
103
155
|
cparams.op_offload = params.op_offload;
|
|
@@ -112,11 +164,28 @@ llama_context::llama_context(
|
|
|
112
164
|
}
|
|
113
165
|
}
|
|
114
166
|
|
|
115
|
-
|
|
167
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
|
|
168
|
+
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
|
|
169
|
+
|
|
170
|
+
if (cparams.kv_unified) {
|
|
171
|
+
cparams.n_ctx_seq = cparams.n_ctx;
|
|
172
|
+
} else {
|
|
173
|
+
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
174
|
+
cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
|
|
175
|
+
|
|
176
|
+
if (cparams.n_ctx_seq == 0) {
|
|
177
|
+
throw std::runtime_error("n_ctx_seq == 0");
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
|
|
181
|
+
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
|
|
182
|
+
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
|
|
183
|
+
}
|
|
184
|
+
}
|
|
116
185
|
|
|
117
186
|
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
|
118
187
|
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
|
119
|
-
LLAMA_LOG_INFO("%s:
|
|
188
|
+
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
|
|
120
189
|
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
|
121
190
|
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
|
122
191
|
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
|
@@ -125,14 +194,14 @@ llama_context::llama_context(
|
|
|
125
194
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
|
126
195
|
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
|
127
196
|
|
|
128
|
-
if (
|
|
129
|
-
LLAMA_LOG_WARN("%s:
|
|
130
|
-
__func__,
|
|
197
|
+
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
|
198
|
+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
|
199
|
+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
|
131
200
|
}
|
|
132
201
|
|
|
133
|
-
if (
|
|
134
|
-
LLAMA_LOG_WARN("%s:
|
|
135
|
-
__func__,
|
|
202
|
+
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
|
|
203
|
+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
|
204
|
+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
|
136
205
|
}
|
|
137
206
|
|
|
138
207
|
if (!hparams.vocab_only) {
|
|
@@ -181,7 +250,10 @@ llama_context::llama_context(
|
|
|
181
250
|
// graph outputs buffer
|
|
182
251
|
{
|
|
183
252
|
// resized during inference when a batch uses more outputs
|
|
184
|
-
|
|
253
|
+
// Create a dummy batch for initialization.
|
|
254
|
+
llama_batch dummy_batch = {};
|
|
255
|
+
dummy_batch.n_tokens = 0;
|
|
256
|
+
if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
|
|
185
257
|
throw std::runtime_error("failed to reserve initial output buffer");
|
|
186
258
|
}
|
|
187
259
|
|
|
@@ -208,6 +280,7 @@ llama_context::llama_context(
|
|
|
208
280
|
|
|
209
281
|
backend_buft.clear();
|
|
210
282
|
backend_ptrs.clear();
|
|
283
|
+
backend_buf_exp_size.clear();
|
|
211
284
|
|
|
212
285
|
for (auto & backend : backends) {
|
|
213
286
|
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
|
@@ -224,11 +297,15 @@ llama_context::llama_context(
|
|
|
224
297
|
|
|
225
298
|
backend_buft.push_back(buft);
|
|
226
299
|
backend_ptrs.push_back(backend.get());
|
|
300
|
+
backend_buf_exp_size.push_back(0);
|
|
227
301
|
}
|
|
228
302
|
|
|
229
303
|
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
|
230
304
|
|
|
231
|
-
const
|
|
305
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
306
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
307
|
+
|
|
308
|
+
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
|
232
309
|
|
|
233
310
|
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
234
311
|
|
|
@@ -239,8 +316,8 @@ llama_context::llama_context(
|
|
|
239
316
|
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
240
317
|
bool pipeline_parallel =
|
|
241
318
|
model.n_devices() > 1 &&
|
|
242
|
-
model.
|
|
243
|
-
model.
|
|
319
|
+
model.n_gpu_layers() > model.hparams.n_layer &&
|
|
320
|
+
model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
|
|
244
321
|
cparams.offload_kqv &&
|
|
245
322
|
!model.has_tensor_overrides();
|
|
246
323
|
|
|
@@ -268,9 +345,7 @@ llama_context::llama_context(
|
|
|
268
345
|
if (pipeline_parallel) {
|
|
269
346
|
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
|
270
347
|
}
|
|
271
|
-
}
|
|
272
348
|
|
|
273
|
-
if (!hparams.vocab_only) {
|
|
274
349
|
llama_memory_context_ptr mctx;
|
|
275
350
|
if (memory) {
|
|
276
351
|
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
|
@@ -282,9 +357,6 @@ llama_context::llama_context(
|
|
|
282
357
|
|
|
283
358
|
cross.v_embd.clear();
|
|
284
359
|
|
|
285
|
-
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
|
286
|
-
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
287
|
-
|
|
288
360
|
// avoid reserving graphs with zero outputs - assume one output per sequence
|
|
289
361
|
n_outputs = n_seqs;
|
|
290
362
|
|
|
@@ -341,9 +413,17 @@ llama_context::llama_context(
|
|
|
341
413
|
|
|
342
414
|
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
|
343
415
|
{
|
|
344
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()
|
|
416
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
|
417
|
+
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
|
345
418
|
if (!gf) {
|
|
346
|
-
|
|
419
|
+
if (pipeline_parallel) {
|
|
420
|
+
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
|
421
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
|
|
422
|
+
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
423
|
+
}
|
|
424
|
+
if (!gf) {
|
|
425
|
+
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
426
|
+
}
|
|
347
427
|
}
|
|
348
428
|
|
|
349
429
|
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
@@ -352,7 +432,7 @@ llama_context::llama_context(
|
|
|
352
432
|
|
|
353
433
|
// reserve with tg (token generation) graph to get the number of splits and nodes
|
|
354
434
|
{
|
|
355
|
-
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
|
435
|
+
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
|
356
436
|
if (!gf) {
|
|
357
437
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
358
438
|
}
|
|
@@ -367,7 +447,7 @@ llama_context::llama_context(
|
|
|
367
447
|
//
|
|
368
448
|
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
|
369
449
|
//
|
|
370
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
450
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
|
|
371
451
|
if (!gf) {
|
|
372
452
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
373
453
|
}
|
|
@@ -376,11 +456,13 @@ llama_context::llama_context(
|
|
|
376
456
|
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
377
457
|
ggml_backend_t backend = backend_ptrs[i];
|
|
378
458
|
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
379
|
-
|
|
380
|
-
|
|
459
|
+
if (!model.hparams.no_alloc) {
|
|
460
|
+
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
461
|
+
}
|
|
462
|
+
if (backend_buf_exp_size[i] > 1) {
|
|
381
463
|
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
382
464
|
ggml_backend_buft_name(buft),
|
|
383
|
-
|
|
465
|
+
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
|
384
466
|
}
|
|
385
467
|
}
|
|
386
468
|
|
|
@@ -396,9 +478,35 @@ llama_context::llama_context(
|
|
|
396
478
|
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
397
479
|
}
|
|
398
480
|
}
|
|
481
|
+
|
|
482
|
+
// Initialize the full vocabulary token ids for backend samplers.
|
|
483
|
+
{
|
|
484
|
+
const int n_vocab = model.vocab.n_tokens();
|
|
485
|
+
|
|
486
|
+
sampling.token_ids_full_vocab.resize(n_vocab);
|
|
487
|
+
for (int i = 0; i < n_vocab; ++i) {
|
|
488
|
+
sampling.token_ids_full_vocab[i] = i;
|
|
489
|
+
}
|
|
490
|
+
}
|
|
399
491
|
}
|
|
400
492
|
|
|
401
493
|
llama_context::~llama_context() {
|
|
494
|
+
if (!model.hparams.no_alloc) {
|
|
495
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
496
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
497
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
498
|
+
|
|
499
|
+
const size_t size_exp = backend_buf_exp_size[i];
|
|
500
|
+
const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
501
|
+
if (size_exp == size_act) {
|
|
502
|
+
LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
|
|
503
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
504
|
+
} else {
|
|
505
|
+
LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
|
506
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
507
|
+
}
|
|
508
|
+
}
|
|
509
|
+
}
|
|
402
510
|
ggml_opt_free(opt_ctx);
|
|
403
511
|
}
|
|
404
512
|
|
|
@@ -448,8 +556,8 @@ uint32_t llama_context::n_ctx() const {
|
|
|
448
556
|
return cparams.n_ctx;
|
|
449
557
|
}
|
|
450
558
|
|
|
451
|
-
uint32_t llama_context::
|
|
452
|
-
return cparams.
|
|
559
|
+
uint32_t llama_context::n_ctx_seq() const {
|
|
560
|
+
return cparams.n_ctx_seq;
|
|
453
561
|
}
|
|
454
562
|
|
|
455
563
|
uint32_t llama_context::n_batch() const {
|
|
@@ -518,7 +626,7 @@ bool llama_context::memory_update(bool optimize) {
|
|
|
518
626
|
throw std::runtime_error("failed to initialize memory context");
|
|
519
627
|
}
|
|
520
628
|
|
|
521
|
-
const uint32_t n_seqs = cparams.
|
|
629
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
522
630
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
523
631
|
|
|
524
632
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
@@ -540,6 +648,35 @@ float * llama_context::get_logits() {
|
|
|
540
648
|
return logits;
|
|
541
649
|
}
|
|
542
650
|
|
|
651
|
+
int64_t llama_context::output_resolve_row(int32_t i) const {
|
|
652
|
+
int64_t j = -1;
|
|
653
|
+
|
|
654
|
+
// support negative indices (last output row)
|
|
655
|
+
if (i < 0) {
|
|
656
|
+
j = n_outputs + i;
|
|
657
|
+
if (j < 0) {
|
|
658
|
+
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
659
|
+
}
|
|
660
|
+
} else if ((size_t) i >= output_ids.size()) {
|
|
661
|
+
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
662
|
+
} else {
|
|
663
|
+
// use output_ids to translate the batch token index into a row number
|
|
664
|
+
// that holds this token's data.
|
|
665
|
+
j = output_ids[i];
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
if (j < 0) {
|
|
669
|
+
// the batch token was not configured to output anything
|
|
670
|
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
if (j >= n_outputs) {
|
|
674
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
return j;
|
|
678
|
+
}
|
|
679
|
+
|
|
543
680
|
float * llama_context::get_logits_ith(int32_t i) {
|
|
544
681
|
int64_t j = -1;
|
|
545
682
|
|
|
@@ -550,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
550
687
|
throw std::runtime_error("no logits");
|
|
551
688
|
}
|
|
552
689
|
|
|
690
|
+
// TODO: use output_resolve_row()
|
|
553
691
|
if (i < 0) {
|
|
554
692
|
j = n_outputs + i;
|
|
555
693
|
if (j < 0) {
|
|
@@ -586,6 +724,10 @@ float * llama_context::get_embeddings() {
|
|
|
586
724
|
return embd;
|
|
587
725
|
}
|
|
588
726
|
|
|
727
|
+
llama_token * llama_context::get_sampled_tokens() const{
|
|
728
|
+
return sampling.sampled;
|
|
729
|
+
}
|
|
730
|
+
|
|
589
731
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
|
590
732
|
int64_t j = -1;
|
|
591
733
|
|
|
@@ -596,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
|
596
738
|
throw std::runtime_error("no embeddings");
|
|
597
739
|
}
|
|
598
740
|
|
|
741
|
+
// TODO: use output_resolve_row()
|
|
599
742
|
if (i < 0) {
|
|
600
743
|
j = n_outputs + i;
|
|
601
744
|
if (j < 0) {
|
|
@@ -615,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
|
615
758
|
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
616
759
|
}
|
|
617
760
|
|
|
618
|
-
|
|
761
|
+
const uint32_t n_embd_out = model.hparams.get_n_embd_out();
|
|
762
|
+
return embd + j*n_embd_out;
|
|
619
763
|
} catch (const std::exception & err) {
|
|
620
764
|
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
|
621
765
|
#ifndef NDEBUG
|
|
@@ -635,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
635
779
|
return it->second.data();
|
|
636
780
|
}
|
|
637
781
|
|
|
782
|
+
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
|
783
|
+
output_reorder();
|
|
784
|
+
|
|
785
|
+
if (sampling.sampled == nullptr) {
|
|
786
|
+
return LLAMA_TOKEN_NULL;
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
try {
|
|
790
|
+
const int64_t row = output_resolve_row(idx);
|
|
791
|
+
GGML_ASSERT(row < (int64_t) sampling.sampled_size);
|
|
792
|
+
return sampling.sampled[row];
|
|
793
|
+
} catch (const std::exception & err) {
|
|
794
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
|
|
795
|
+
return LLAMA_TOKEN_NULL;
|
|
796
|
+
}
|
|
797
|
+
}
|
|
798
|
+
|
|
799
|
+
float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
|
800
|
+
output_reorder();
|
|
801
|
+
|
|
802
|
+
if (sampling.probs == nullptr) {
|
|
803
|
+
return nullptr;
|
|
804
|
+
}
|
|
805
|
+
|
|
806
|
+
try {
|
|
807
|
+
const int64_t row = output_resolve_row(idx);
|
|
808
|
+
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
|
|
809
|
+
return nullptr;
|
|
810
|
+
}
|
|
811
|
+
return sampling.probs + row*model.vocab.n_tokens();
|
|
812
|
+
} catch (const std::exception & err) {
|
|
813
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
|
|
814
|
+
return nullptr;
|
|
815
|
+
}
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
|
819
|
+
output_reorder();
|
|
820
|
+
|
|
821
|
+
if (sampling.logits == nullptr) {
|
|
822
|
+
return nullptr;
|
|
823
|
+
}
|
|
824
|
+
|
|
825
|
+
try {
|
|
826
|
+
const int64_t row = output_resolve_row(idx);
|
|
827
|
+
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
|
|
828
|
+
return nullptr;
|
|
829
|
+
}
|
|
830
|
+
return sampling.logits + row*model.vocab.n_tokens();
|
|
831
|
+
} catch (const std::exception & err) {
|
|
832
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
|
|
833
|
+
return nullptr;
|
|
834
|
+
}
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
// Candidate token ids for output `idx`.
//
// When backend candidates were recorded for the resolved row, returns a pointer into
// `sampling.candidates` at that row (rows are model.vocab.n_tokens() tokens apart).
// Otherwise returns `sampling.token_ids_full_vocab.data()`. This covers: no candidate
// buffer, row out of range, empty row, or output_resolve_row() throwing. The exception
// is intentionally not logged in that last case.
const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
    output_reorder();

    try {
        const int64_t out_row = output_resolve_row(idx);

        const bool has_row_candidates =
            sampling.candidates != nullptr &&
            (size_t) out_row < sampling.candidates_count.size() &&
            sampling.candidates_count[out_row] > 0;

        if (has_row_candidates) {
            return sampling.candidates + out_row*model.vocab.n_tokens();
        }
    } catch (const std::exception &) {
        // deliberately ignored: fall back to the full vocabulary list below
    }

    return sampling.token_ids_full_vocab.data();
}
|
|
853
|
+
|
|
854
|
+
size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
|
855
|
+
output_reorder();
|
|
856
|
+
|
|
857
|
+
if (sampling.candidates == nullptr) {
|
|
858
|
+
return 0;
|
|
859
|
+
}
|
|
860
|
+
|
|
861
|
+
try {
|
|
862
|
+
const int64_t row = output_resolve_row(idx);
|
|
863
|
+
if ((size_t) row >= sampling.candidates_count.size()) {
|
|
864
|
+
return 0;
|
|
865
|
+
}
|
|
866
|
+
return sampling.candidates_count[row];
|
|
867
|
+
} catch (const std::exception & err) {
|
|
868
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
|
|
869
|
+
return 0;
|
|
870
|
+
}
|
|
871
|
+
}
|
|
872
|
+
|
|
873
|
+
size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
|
874
|
+
output_reorder();
|
|
875
|
+
|
|
876
|
+
if (sampling.logits == nullptr) {
|
|
877
|
+
return model.vocab.n_tokens();
|
|
878
|
+
}
|
|
879
|
+
|
|
880
|
+
try {
|
|
881
|
+
const int64_t row = output_resolve_row(idx);
|
|
882
|
+
if ((size_t) row >= sampling.logits_count.size()) {
|
|
883
|
+
return 0;
|
|
884
|
+
}
|
|
885
|
+
return sampling.logits_count[row];
|
|
886
|
+
} catch (const std::exception & err) {
|
|
887
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
|
|
888
|
+
return 0;
|
|
889
|
+
}
|
|
890
|
+
}
|
|
891
|
+
|
|
892
|
+
size_t llama_context::get_sampled_probs_count(int32_t idx) {
|
|
893
|
+
output_reorder();
|
|
894
|
+
|
|
895
|
+
if (sampling.probs == nullptr) {
|
|
896
|
+
return 0;
|
|
897
|
+
}
|
|
898
|
+
|
|
899
|
+
try {
|
|
900
|
+
const int64_t row = output_resolve_row(idx);
|
|
901
|
+
if ((size_t) row >= sampling.probs_count.size()) {
|
|
902
|
+
return 0;
|
|
903
|
+
}
|
|
904
|
+
return sampling.probs_count[row];
|
|
905
|
+
} catch (const std::exception & err) {
|
|
906
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
|
|
907
|
+
return 0;
|
|
908
|
+
}
|
|
909
|
+
}
|
|
910
|
+
|
|
911
|
+
|
|
638
912
|
void llama_context::attach_threadpool(
|
|
639
913
|
ggml_threadpool_t threadpool,
|
|
640
914
|
ggml_threadpool_t threadpool_batch) {
|
|
@@ -691,6 +965,42 @@ void llama_context::set_warmup(bool value) {
|
|
|
691
965
|
cparams.warmup = value;
|
|
692
966
|
}
|
|
693
967
|
|
|
968
|
+
// Attach (or detach) a backend sampler for sequence `seq_id`.
//
// - sampler == nullptr: removes any sampler registered for seq_id and returns true.
// - The sampler cannot be offloaded: it lacks iface->backend_init or iface->backend_apply,
//   or llama_sampler_chain_n() reports 0. A warning is logged, any sampler registered for
//   seq_id is removed, and the function returns false.
// - Otherwise: calls iface->backend_init() with the output device's host buffer type
//   (falling back to its device buffer type when no host type exists), registers the
//   sampler for seq_id and returns true.
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);

    if (sampler == nullptr) {
        sampling.samplers.erase(seq_id);
        return true;
    }

    const bool offloadable =
        sampler->iface->backend_init  != nullptr &&
        sampler->iface->backend_apply != nullptr &&
        llama_sampler_chain_n(sampler) > 0;

    if (!offloadable) {
        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
        sampling.samplers.erase(seq_id);
        return false;
    }

    // prefer the host buffer type of the output device when it exists
    ggml_backend_buffer_type_t dev_buft  = ggml_backend_dev_buffer_type(model.dev_output());
    auto *                     host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());

    sampler->iface->backend_init(sampler, host_buft ? host_buft : dev_buft);

    sampling.samplers[seq_id] = sampler;
    return true;
}
|
|
1003
|
+
|
|
694
1004
|
void llama_context::set_adapter_lora(
|
|
695
1005
|
llama_adapter_lora * adapter,
|
|
696
1006
|
float scale) {
|
|
@@ -803,7 +1113,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
803
1113
|
|
|
804
1114
|
const auto & hparams = model.hparams;
|
|
805
1115
|
|
|
806
|
-
const int64_t n_embd = hparams.
|
|
1116
|
+
const int64_t n_embd = hparams.n_embd_inp();
|
|
807
1117
|
const int64_t n_vocab = model.vocab.n_tokens();
|
|
808
1118
|
|
|
809
1119
|
// note: during encode, we always pass the full sequence starting from pos = 0
|
|
@@ -831,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
831
1141
|
n_queued_tokens += n_tokens;
|
|
832
1142
|
|
|
833
1143
|
// reserve output buffer
|
|
834
|
-
if (output_reserve(n_tokens) < n_tokens) {
|
|
1144
|
+
if (output_reserve(n_tokens, batch_inp) < n_tokens) {
|
|
835
1145
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
|
836
1146
|
return -2;
|
|
837
1147
|
};
|
|
@@ -885,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
885
1195
|
{
|
|
886
1196
|
// extract token embeddings
|
|
887
1197
|
GGML_ASSERT(embd != nullptr);
|
|
1198
|
+
const uint32_t n_embd_out = hparams.get_n_embd_out();
|
|
888
1199
|
|
|
889
|
-
GGML_ASSERT(n_tokens*
|
|
890
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*
|
|
1200
|
+
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
|
|
1201
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
|
|
891
1202
|
} break;
|
|
892
1203
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
893
1204
|
case LLAMA_POOLING_TYPE_CLS:
|
|
@@ -955,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
955
1266
|
return 0;
|
|
956
1267
|
}
|
|
957
1268
|
|
|
1269
|
+
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
|
1270
|
+
std::map<llama_seq_id, uint32_t> seq_to_row;
|
|
1271
|
+
// how many output tokens we have seen so far for this ubatch.
|
|
1272
|
+
uint32_t local = 0;
|
|
1273
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
1274
|
+
// skip tokens that are not output.
|
|
1275
|
+
if (!ubatch.output[i]) {
|
|
1276
|
+
continue;
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
1280
|
+
// row_offset is the number of output tokens before this ubatch.
|
|
1281
|
+
seq_to_row[seq_id] = row_offset + local;
|
|
1282
|
+
++local;
|
|
1283
|
+
}
|
|
1284
|
+
return seq_to_row;
|
|
1285
|
+
}
|
|
1286
|
+
|
|
1287
|
+
static void copy_tensor_async_ints(
|
|
1288
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1289
|
+
llama_token * sampled,
|
|
1290
|
+
size_t sampled_size,
|
|
1291
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1292
|
+
ggml_backend_sched_t sched) {
|
|
1293
|
+
if (sampled == nullptr) {
|
|
1294
|
+
return;
|
|
1295
|
+
}
|
|
1296
|
+
|
|
1297
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1298
|
+
auto it = seq_to_row.find(seq_id);
|
|
1299
|
+
if (it == seq_to_row.end()) {
|
|
1300
|
+
continue;
|
|
1301
|
+
}
|
|
1302
|
+
|
|
1303
|
+
const uint32_t row = it->second;
|
|
1304
|
+
GGML_ASSERT(row < sampled_size);
|
|
1305
|
+
|
|
1306
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
|
1307
|
+
|
|
1308
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1309
|
+
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
|
|
1310
|
+
}
|
|
1311
|
+
}
|
|
1312
|
+
|
|
1313
|
+
static void copy_tensor_async_floats(
|
|
1314
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1315
|
+
float * dst,
|
|
1316
|
+
size_t stride,
|
|
1317
|
+
std::vector<uint32_t> & counts,
|
|
1318
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1319
|
+
ggml_backend_sched_t sched) {
|
|
1320
|
+
if (dst == nullptr) {
|
|
1321
|
+
return;
|
|
1322
|
+
}
|
|
1323
|
+
|
|
1324
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1325
|
+
auto it = seq_to_row.find(seq_id);
|
|
1326
|
+
if (it == seq_to_row.end()) {
|
|
1327
|
+
continue;
|
|
1328
|
+
}
|
|
1329
|
+
|
|
1330
|
+
const uint32_t row = it->second;
|
|
1331
|
+
GGML_ASSERT(row < counts.size());
|
|
1332
|
+
|
|
1333
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
|
1334
|
+
|
|
1335
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1336
|
+
float * row_ptr = dst + (size_t) row * stride;
|
|
1337
|
+
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1338
|
+
|
|
1339
|
+
// Update the actual number of logits/probabilities that were written for this row.
|
|
1340
|
+
counts[row] = ggml_nelements(tensor);
|
|
1341
|
+
}
|
|
1342
|
+
}
|
|
1343
|
+
|
|
1344
|
+
static void copy_tensor_async_candidates(
|
|
1345
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1346
|
+
llama_token * dst,
|
|
1347
|
+
size_t stride,
|
|
1348
|
+
std::vector<uint32_t> & counts,
|
|
1349
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1350
|
+
ggml_backend_sched_t sched) {
|
|
1351
|
+
if (dst == nullptr) {
|
|
1352
|
+
return;
|
|
1353
|
+
}
|
|
1354
|
+
|
|
1355
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1356
|
+
auto it = seq_to_row.find(seq_id);
|
|
1357
|
+
if (it == seq_to_row.end()) {
|
|
1358
|
+
continue;
|
|
1359
|
+
}
|
|
1360
|
+
|
|
1361
|
+
const uint32_t row = it->second;
|
|
1362
|
+
GGML_ASSERT(row < counts.size());
|
|
1363
|
+
|
|
1364
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
|
1365
|
+
|
|
1366
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1367
|
+
llama_token * row_ptr = dst + (size_t) row * stride;
|
|
1368
|
+
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1369
|
+
|
|
1370
|
+
// Update the actual number of candidates that were written.
|
|
1371
|
+
counts[row] = ggml_nelements(tensor);
|
|
1372
|
+
}
|
|
1373
|
+
}
|
|
1374
|
+
|
|
958
1375
|
int llama_context::decode(const llama_batch & batch_inp) {
|
|
959
1376
|
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
|
960
1377
|
|
|
@@ -972,12 +1389,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
972
1389
|
const auto & hparams = model.hparams;
|
|
973
1390
|
|
|
974
1391
|
const int64_t n_vocab = vocab.n_tokens();
|
|
975
|
-
const int64_t n_embd = hparams.
|
|
1392
|
+
const int64_t n_embd = hparams.n_embd_inp();
|
|
976
1393
|
|
|
977
1394
|
// when computing embeddings, all tokens are output
|
|
978
|
-
const bool output_all
|
|
1395
|
+
const bool output_all = cparams.embeddings;
|
|
1396
|
+
const bool has_samplers = !sampling.samplers.empty();
|
|
1397
|
+
|
|
1398
|
+
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
|
|
1399
|
+
|
|
1400
|
+
// TODO: avoid this workaround in the future
|
|
1401
|
+
if (has_samplers && batch_inp.logits) {
|
|
1402
|
+
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
|
1403
|
+
|
|
1404
|
+
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
|
|
1405
|
+
if (batch_inp.logits[i] == 0) {
|
|
1406
|
+
continue;
|
|
1407
|
+
}
|
|
979
1408
|
|
|
980
|
-
|
|
1409
|
+
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
|
|
1410
|
+
|
|
1411
|
+
for (int32_t s = 0; s < ns; ++s) {
|
|
1412
|
+
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
|
|
1413
|
+
|
|
1414
|
+
seq_output_count[seq_id]++;
|
|
1415
|
+
if (seq_output_count[seq_id] > 1) {
|
|
1416
|
+
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
|
|
1417
|
+
__func__, seq_id, seq_output_count[seq_id]);
|
|
1418
|
+
return -1;
|
|
1419
|
+
}
|
|
1420
|
+
}
|
|
1421
|
+
}
|
|
1422
|
+
}
|
|
1423
|
+
|
|
1424
|
+
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
|
|
981
1425
|
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
982
1426
|
return -1;
|
|
983
1427
|
}
|
|
@@ -1058,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1058
1502
|
}
|
|
1059
1503
|
|
|
1060
1504
|
// reserve output buffer
|
|
1061
|
-
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
1505
|
+
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
|
|
1062
1506
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
1063
1507
|
return -2;
|
|
1064
1508
|
};
|
|
@@ -1131,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1131
1575
|
}
|
|
1132
1576
|
|
|
1133
1577
|
// extract logits
|
|
1134
|
-
|
|
1578
|
+
// For multi-sequence batches that mix backend samplers and CPU sampler
|
|
1579
|
+
// this is currently inefficient as we copy all logits even for the
|
|
1580
|
+
// backend sampled tokens.
|
|
1581
|
+
if (logits && t_logits && n_outputs > 0) {
|
|
1135
1582
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
1136
1583
|
GGML_ASSERT(backend_res != nullptr);
|
|
1137
1584
|
GGML_ASSERT(logits != nullptr);
|
|
@@ -1146,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1146
1593
|
}
|
|
1147
1594
|
|
|
1148
1595
|
// extract embeddings
|
|
1149
|
-
if (t_embd && n_outputs > 0) {
|
|
1596
|
+
if (embd && t_embd && n_outputs > 0) {
|
|
1150
1597
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1151
1598
|
GGML_ASSERT(backend_embd != nullptr);
|
|
1152
1599
|
|
|
@@ -1155,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1155
1602
|
{
|
|
1156
1603
|
// extract token embeddings
|
|
1157
1604
|
GGML_ASSERT(embd != nullptr);
|
|
1158
|
-
|
|
1605
|
+
const uint32_t n_embd_out = hparams.get_n_embd_out();
|
|
1606
|
+
float * embd_out = embd + n_outputs_prev*n_embd_out;
|
|
1159
1607
|
|
|
1160
1608
|
if (n_outputs) {
|
|
1161
1609
|
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1162
|
-
GGML_ASSERT((n_outputs_prev + n_outputs)*
|
|
1163
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*
|
|
1610
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
|
|
1611
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
|
|
1164
1612
|
}
|
|
1165
1613
|
} break;
|
|
1166
1614
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
@@ -1200,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1200
1648
|
}
|
|
1201
1649
|
}
|
|
1202
1650
|
|
|
1651
|
+
// This flag indicates whether a backend sampler has actually sampled a specific
|
|
1652
|
+
// token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
|
|
1653
|
+
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
|
1654
|
+
|
|
1655
|
+
if (has_samplers && has_sampled) {
|
|
1656
|
+
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
|
1657
|
+
const auto stride = n_vocab;
|
|
1658
|
+
|
|
1659
|
+
// async copy the sampling data from the backend to the host
|
|
1660
|
+
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
|
|
1661
|
+
|
|
1662
|
+
copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
|
|
1663
|
+
copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
|
|
1664
|
+
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
|
|
1665
|
+
}
|
|
1666
|
+
|
|
1203
1667
|
n_outputs_prev += n_outputs;
|
|
1204
1668
|
} while (mctx->next());
|
|
1205
1669
|
|
|
@@ -1224,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1224
1688
|
|
|
1225
1689
|
// make the outputs have the same order they had in the user-provided batch
|
|
1226
1690
|
// note: this is mostly relevant for recurrent models atm
|
|
1227
|
-
if (!sorted_output) {
|
|
1691
|
+
if (!sorted_output && n_outputs > 1) {
|
|
1228
1692
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1229
1693
|
|
|
1230
1694
|
// TODO: is there something more efficient which also minimizes swaps?
|
|
@@ -1263,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1263
1727
|
// output
|
|
1264
1728
|
//
|
|
1265
1729
|
|
|
1266
|
-
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1730
|
+
uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
|
|
1267
1731
|
const auto & hparams = model.hparams;
|
|
1268
1732
|
const auto & vocab = model.vocab;
|
|
1269
1733
|
|
|
1270
1734
|
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
|
|
1271
1735
|
|
|
1272
|
-
const auto n_batch
|
|
1273
|
-
const auto n_vocab
|
|
1274
|
-
const auto
|
|
1736
|
+
const auto n_batch = cparams.n_batch;
|
|
1737
|
+
const auto n_vocab = vocab.n_tokens();
|
|
1738
|
+
const auto n_embd_out = hparams.get_n_embd_out();
|
|
1275
1739
|
|
|
1276
1740
|
bool has_logits = true;
|
|
1277
1741
|
bool has_embd = cparams.embeddings;
|
|
@@ -1282,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1282
1746
|
has_embd = true;
|
|
1283
1747
|
}
|
|
1284
1748
|
|
|
1285
|
-
|
|
1286
|
-
|
|
1749
|
+
// Check which sampling modes are needed for the current batch.
|
|
1750
|
+
// TODO: avoid this branching by working with the worst-case
|
|
1751
|
+
bool has_sampling = false;
|
|
1752
|
+
bool cpu_logits = false;
|
|
1753
|
+
|
|
1754
|
+
if (batch.logits) {
|
|
1755
|
+
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
1756
|
+
if (!batch.logits[i]) {
|
|
1757
|
+
continue;
|
|
1758
|
+
}
|
|
1759
|
+
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
|
1760
|
+
llama_seq_id seq_id = batch.seq_id[i][j];
|
|
1761
|
+
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
|
|
1762
|
+
has_sampling = true;
|
|
1763
|
+
} else {
|
|
1764
|
+
cpu_logits = true;
|
|
1765
|
+
}
|
|
1766
|
+
}
|
|
1767
|
+
}
|
|
1768
|
+
} else {
|
|
1769
|
+
// When batch.logits is nullptr (when loading state with a dummy batch),
|
|
1770
|
+
// allocate CPU logits.
|
|
1771
|
+
cpu_logits = true;
|
|
1772
|
+
}
|
|
1773
|
+
|
|
1774
|
+
size_t backend_float_count = 0;
|
|
1775
|
+
size_t backend_token_count = 0;
|
|
1776
|
+
|
|
1777
|
+
// Allocate CPU logits buffer only if needed by sequences in this batch
|
|
1778
|
+
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
|
|
1779
|
+
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
|
|
1780
|
+
|
|
1781
|
+
// TODO: avoid this branching by working with the worst-case
|
|
1782
|
+
if (!has_sampling) {
|
|
1783
|
+
sampling.logits_size = 0;
|
|
1784
|
+
sampling.probs_size = 0;
|
|
1785
|
+
sampling.sampled_size = 0;
|
|
1786
|
+
sampling.candidates_size = 0;
|
|
1787
|
+
} else {
|
|
1788
|
+
sampling.logits_size = n_vocab*n_outputs_max;
|
|
1789
|
+
sampling.probs_size = n_vocab*n_outputs_max;
|
|
1790
|
+
sampling.sampled_size = n_outputs_max;
|
|
1791
|
+
sampling.candidates_size = n_vocab*n_outputs_max;
|
|
1792
|
+
|
|
1793
|
+
backend_float_count = sampling.logits_size + sampling.probs_size;
|
|
1794
|
+
backend_token_count = sampling.sampled_size + sampling.candidates_size;
|
|
1795
|
+
}
|
|
1287
1796
|
|
|
1288
1797
|
if (output_ids.empty()) {
|
|
1289
1798
|
// init, never resized afterwards
|
|
@@ -1291,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1291
1800
|
}
|
|
1292
1801
|
|
|
1293
1802
|
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
|
1294
|
-
const size_t new_size =
|
|
1803
|
+
const size_t new_size =
|
|
1804
|
+
(logits_size + embd_size + backend_float_count) * sizeof(float) +
|
|
1805
|
+
( backend_token_count) * sizeof(llama_token);
|
|
1295
1806
|
|
|
1296
1807
|
// alloc only when more than the current capacity is required
|
|
1297
1808
|
// TODO: also consider shrinking the buffer
|
|
@@ -1299,8 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1299
1810
|
if (buf_output) {
|
|
1300
1811
|
#ifndef NDEBUG
|
|
1301
1812
|
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
|
1302
|
-
|
|
1813
|
+
LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
1303
1814
|
#endif
|
|
1815
|
+
synchronize();
|
|
1816
|
+
|
|
1817
|
+
// TODO: not needed?
|
|
1304
1818
|
buf_output = nullptr;
|
|
1305
1819
|
logits = nullptr;
|
|
1306
1820
|
embd = nullptr;
|
|
@@ -1322,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1322
1836
|
|
|
1323
1837
|
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
|
1324
1838
|
|
|
1325
|
-
logits =
|
|
1326
|
-
embd =
|
|
1839
|
+
logits = nullptr;
|
|
1840
|
+
embd = nullptr;
|
|
1841
|
+
|
|
1842
|
+
size_t offset = 0;
|
|
1843
|
+
uint8_t * base = (uint8_t *) output_base;
|
|
1844
|
+
|
|
1845
|
+
logits = (has_logits && cpu_logits) ? output_base : nullptr;
|
|
1846
|
+
offset += logits_size * sizeof(float);
|
|
1847
|
+
|
|
1848
|
+
embd = has_embd ? (float *) (base + offset) : nullptr;
|
|
1849
|
+
offset += embd_size * sizeof(float);
|
|
1850
|
+
|
|
1851
|
+
sampling.logits = nullptr;
|
|
1852
|
+
sampling.probs = nullptr;
|
|
1853
|
+
sampling.sampled = nullptr;
|
|
1854
|
+
sampling.candidates = nullptr;
|
|
1855
|
+
|
|
1856
|
+
if (has_sampling) {
|
|
1857
|
+
sampling.logits = (float *) (base + offset);
|
|
1858
|
+
offset += sampling.logits_size * sizeof(float);
|
|
1859
|
+
|
|
1860
|
+
sampling.probs = (float *) (base + offset);
|
|
1861
|
+
offset += sampling.probs_size * sizeof(float);
|
|
1862
|
+
|
|
1863
|
+
sampling.sampled = (llama_token *) (base + offset);
|
|
1864
|
+
offset += sampling.sampled_size * sizeof(llama_token);
|
|
1865
|
+
|
|
1866
|
+
sampling.candidates = (llama_token *) (base + offset);
|
|
1867
|
+
offset += sampling.candidates_size * sizeof(llama_token);
|
|
1868
|
+
|
|
1869
|
+
// The count vectors keep track of the actual number of logits/probs/candidates
|
|
1870
|
+
// copied from the backend for each output row.
|
|
1871
|
+
|
|
1872
|
+
sampling.logits_count.resize(n_outputs_max);
|
|
1873
|
+
sampling.probs_count.resize(n_outputs_max);
|
|
1874
|
+
sampling.candidates_count.resize(n_outputs_max);
|
|
1875
|
+
|
|
1876
|
+
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
|
1877
|
+
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
|
1878
|
+
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
|
1879
|
+
|
|
1880
|
+
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
|
1881
|
+
}
|
|
1327
1882
|
|
|
1328
1883
|
// set all ids as invalid (negative)
|
|
1329
1884
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
@@ -1352,6 +1907,40 @@ void llama_context::output_reorder() {
|
|
|
1352
1907
|
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
|
1353
1908
|
}
|
|
1354
1909
|
}
|
|
1910
|
+
|
|
1911
|
+
if (sampling.logits && sampling.logits_size > 0) {
|
|
1912
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1913
|
+
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
|
|
1914
|
+
}
|
|
1915
|
+
}
|
|
1916
|
+
|
|
1917
|
+
if (sampling.probs && sampling.probs_size > 0) {
|
|
1918
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1919
|
+
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
|
|
1920
|
+
}
|
|
1921
|
+
}
|
|
1922
|
+
|
|
1923
|
+
if (sampling.candidates && sampling.candidates_size > 0) {
|
|
1924
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1925
|
+
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
|
|
1926
|
+
}
|
|
1927
|
+
}
|
|
1928
|
+
|
|
1929
|
+
if (sampling.sampled && sampling.sampled_size > 0) {
|
|
1930
|
+
std::swap(sampling.sampled[i0], sampling.sampled[i1]);
|
|
1931
|
+
}
|
|
1932
|
+
|
|
1933
|
+
if (!sampling.logits_count.empty()) {
|
|
1934
|
+
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
|
1935
|
+
}
|
|
1936
|
+
|
|
1937
|
+
if (!sampling.probs_count.empty()) {
|
|
1938
|
+
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
|
1939
|
+
}
|
|
1940
|
+
|
|
1941
|
+
if (!sampling.candidates_count.empty()) {
|
|
1942
|
+
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
|
1943
|
+
}
|
|
1355
1944
|
}
|
|
1356
1945
|
|
|
1357
1946
|
output_swaps.clear();
|
|
@@ -1361,21 +1950,27 @@ void llama_context::output_reorder() {
|
|
|
1361
1950
|
// graph
|
|
1362
1951
|
//
|
|
1363
1952
|
|
|
1364
|
-
uint32_t llama_context::graph_max_nodes() const {
|
|
1365
|
-
|
|
1953
|
+
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
|
1954
|
+
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
|
1955
|
+
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
|
1956
|
+
}
|
|
1957
|
+
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
|
1958
|
+
res += model.n_lora_nodes;
|
|
1959
|
+
return res;
|
|
1366
1960
|
}
|
|
1367
1961
|
|
|
1368
1962
|
llm_graph_result * llama_context::get_gf_res_reserve() const {
|
|
1369
1963
|
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
|
1370
1964
|
}
|
|
1371
1965
|
|
|
1372
|
-
ggml_cgraph * llama_context::graph_reserve(
|
|
1966
|
+
ggml_cgraph * llama_context::graph_reserve(
|
|
1967
|
+
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
|
|
1373
1968
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1374
1969
|
GGML_ASSERT(n_outputs >= 1);
|
|
1375
1970
|
|
|
1376
1971
|
if (n_tokens % n_seqs != 0) {
|
|
1377
1972
|
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
|
1378
|
-
n_outputs = std::
|
|
1973
|
+
n_outputs = std::max(n_outputs, n_tokens);
|
|
1379
1974
|
|
|
1380
1975
|
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1381
1976
|
}
|
|
@@ -1394,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
1394
1989
|
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
|
1395
1990
|
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
|
1396
1991
|
|
|
1992
|
+
// set one output token per sequence in order to activate all backend samplers
|
|
1993
|
+
std::vector<llama_seq_id> seq_ids(n_seqs);
|
|
1994
|
+
for (uint32_t i = 0; i < n_seqs; ++i) {
|
|
1995
|
+
seq_ids[i] = i;
|
|
1996
|
+
ubatch.n_seq_id[i] = 1;
|
|
1997
|
+
ubatch.seq_id[i] = &seq_ids[i];
|
|
1998
|
+
ubatch.output[i] = true;
|
|
1999
|
+
}
|
|
2000
|
+
|
|
1397
2001
|
auto * res = gf_res_reserve.get();
|
|
1398
2002
|
|
|
1399
2003
|
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
|
@@ -1406,8 +2010,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
1406
2010
|
|
|
1407
2011
|
// initialize scheduler with the specified graph
|
|
1408
2012
|
if (split_only) {
|
|
1409
|
-
|
|
2013
|
+
if (sizes) {
|
|
2014
|
+
ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
|
|
2015
|
+
} else {
|
|
2016
|
+
ggml_backend_sched_split_graph(sched.get(), gf);
|
|
2017
|
+
}
|
|
1410
2018
|
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
2019
|
+
GGML_ASSERT(!sizes);
|
|
1411
2020
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
1412
2021
|
return nullptr;
|
|
1413
2022
|
}
|
|
@@ -1419,7 +2028,7 @@ llm_graph_params llama_context::graph_params(
|
|
|
1419
2028
|
llm_graph_result * res,
|
|
1420
2029
|
const llama_ubatch & ubatch,
|
|
1421
2030
|
const llama_memory_context_i * mctx,
|
|
1422
|
-
|
|
2031
|
+
llm_graph_type gtype) const {
|
|
1423
2032
|
return {
|
|
1424
2033
|
/*.arch =*/ model.arch,
|
|
1425
2034
|
/*.hparams =*/ model.hparams,
|
|
@@ -1432,6 +2041,7 @@ llm_graph_params llama_context::graph_params(
|
|
|
1432
2041
|
/*.loras =*/ &loras,
|
|
1433
2042
|
/*.mctx =*/ mctx,
|
|
1434
2043
|
/*.cross =*/ &cross,
|
|
2044
|
+
/*.samplers =*/ sampling.samplers,
|
|
1435
2045
|
/*.n_outputs =*/ n_outputs,
|
|
1436
2046
|
/*.cb =*/ graph_get_cb(),
|
|
1437
2047
|
/*.res =*/ res,
|
|
@@ -1484,7 +2094,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
|
|
|
1484
2094
|
|
|
1485
2095
|
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
1486
2096
|
// FIXME: fix in ggml_backend_sched
|
|
1487
|
-
const bool full_offload = model.
|
|
2097
|
+
const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
|
|
1488
2098
|
if (ubatch.n_tokens < 32 || full_offload) {
|
|
1489
2099
|
if (il != -1 && strcmp(name, "norm") == 0) {
|
|
1490
2100
|
const auto & dev_layer = model.dev_layer(il);
|
|
@@ -1887,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1887
2497
|
}
|
|
1888
2498
|
}
|
|
1889
2499
|
|
|
2500
|
+
// TODO: handle sampling buffers and samplers state ?
|
|
2501
|
+
// https://github.com/ggml-org/llama.cpp/pull/17004
|
|
2502
|
+
|
|
1890
2503
|
if (memory != nullptr) {
|
|
1891
2504
|
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
|
1892
2505
|
memory->state_write(io);
|
|
@@ -1919,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
1919
2532
|
auto n_outputs = this->n_outputs;
|
|
1920
2533
|
io.read_to(&n_outputs, sizeof(n_outputs));
|
|
1921
2534
|
|
|
1922
|
-
|
|
2535
|
+
// Create a dummy batch for state loading.
|
|
2536
|
+
llama_batch dummy_batch = {};
|
|
2537
|
+
dummy_batch.n_tokens = 0;
|
|
2538
|
+
if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
|
|
1923
2539
|
throw std::runtime_error("could not reserve outputs");
|
|
1924
2540
|
}
|
|
1925
2541
|
|
|
@@ -1973,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
1973
2589
|
}
|
|
1974
2590
|
}
|
|
1975
2591
|
|
|
2592
|
+
// TODO: handle sampling buffers and samplers state ?
|
|
2593
|
+
// https://github.com/ggml-org/llama.cpp/pull/17004
|
|
2594
|
+
|
|
1976
2595
|
if (memory) {
|
|
1977
2596
|
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
|
1978
2597
|
|
|
@@ -2029,15 +2648,26 @@ void llama_context::perf_reset() {
|
|
|
2029
2648
|
|
|
2030
2649
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
|
2031
2650
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
|
2032
|
-
for (const auto &
|
|
2033
|
-
ret[
|
|
2651
|
+
for (const auto & [buft, size] : model.memory_breakdown()) {
|
|
2652
|
+
ret[buft].model += size;
|
|
2034
2653
|
}
|
|
2035
|
-
|
|
2036
|
-
|
|
2654
|
+
if (memory) {
|
|
2655
|
+
for (const auto & [buft, size] : memory->memory_breakdown()) {
|
|
2656
|
+
ret[buft].context += size;
|
|
2657
|
+
}
|
|
2037
2658
|
}
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2659
|
+
if (model.hparams.no_alloc) {
|
|
2660
|
+
for (size_t i = 0; i < backends.size(); ++i) {
|
|
2661
|
+
ggml_backend_t backend = backends[i].get();
|
|
2662
|
+
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
|
2663
|
+
ret[buft].compute += backend_buf_exp_size[i];
|
|
2664
|
+
}
|
|
2665
|
+
} else {
|
|
2666
|
+
for (const auto & backend_ptr : backends) {
|
|
2667
|
+
ggml_backend_t backend = backend_ptr.get();
|
|
2668
|
+
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
|
2669
|
+
ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
2670
|
+
}
|
|
2041
2671
|
}
|
|
2042
2672
|
return ret;
|
|
2043
2673
|
}
|
|
@@ -2130,7 +2760,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2130
2760
|
batch.logits [pos_batch] = true;
|
|
2131
2761
|
}
|
|
2132
2762
|
|
|
2133
|
-
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.
|
|
2763
|
+
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
|
2134
2764
|
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
2135
2765
|
return;
|
|
2136
2766
|
}
|
|
@@ -2150,7 +2780,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2150
2780
|
}
|
|
2151
2781
|
|
|
2152
2782
|
// reserve output buffer
|
|
2153
|
-
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
2783
|
+
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
|
|
2154
2784
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
2155
2785
|
GGML_ABORT("TODO: handle this error");
|
|
2156
2786
|
};
|
|
@@ -2295,6 +2925,8 @@ llama_context_params llama_context_default_params() {
|
|
|
2295
2925
|
/*.op_offload =*/ true,
|
|
2296
2926
|
/*.swa_full =*/ true,
|
|
2297
2927
|
/*.kv_unified =*/ false,
|
|
2928
|
+
/*.sampler =*/ nullptr,
|
|
2929
|
+
/*.n_sampler =*/ 0,
|
|
2298
2930
|
};
|
|
2299
2931
|
|
|
2300
2932
|
return result;
|
|
@@ -2346,6 +2978,13 @@ llama_context * llama_init_from_model(
|
|
|
2346
2978
|
return nullptr;
|
|
2347
2979
|
}
|
|
2348
2980
|
|
|
2981
|
+
if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
|
|
2982
|
+
params.pooling_type != model->hparams.pooling_type) {
|
|
2983
|
+
//user-specified pooling-type is different from the model default
|
|
2984
|
+
LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
|
|
2985
|
+
model->hparams.pooling_type, params.pooling_type);
|
|
2986
|
+
}
|
|
2987
|
+
|
|
2349
2988
|
try {
|
|
2350
2989
|
auto * ctx = new llama_context(*model, params);
|
|
2351
2990
|
return ctx;
|
|
@@ -2371,6 +3010,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
|
|
|
2371
3010
|
return ctx->n_ctx();
|
|
2372
3011
|
}
|
|
2373
3012
|
|
|
3013
|
+
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
|
|
3014
|
+
return ctx->n_ctx_seq();
|
|
3015
|
+
}
|
|
3016
|
+
|
|
2374
3017
|
uint32_t llama_n_batch(const llama_context * ctx) {
|
|
2375
3018
|
return ctx->n_batch();
|
|
2376
3019
|
}
|
|
@@ -2443,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
|
|
|
2443
3086
|
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
|
2444
3087
|
ctx->synchronize();
|
|
2445
3088
|
|
|
2446
|
-
|
|
3089
|
+
float * res = nullptr;
|
|
3090
|
+
|
|
3091
|
+
res = ctx->get_sampled_logits_ith(i);
|
|
3092
|
+
|
|
3093
|
+
if (!res) {
|
|
3094
|
+
res = ctx->get_logits_ith(i);
|
|
3095
|
+
}
|
|
3096
|
+
|
|
3097
|
+
return res;
|
|
2447
3098
|
}
|
|
2448
3099
|
|
|
2449
3100
|
float * llama_get_embeddings(llama_context * ctx) {
|
|
@@ -2464,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|
|
2464
3115
|
return ctx->get_embeddings_seq(seq_id);
|
|
2465
3116
|
}
|
|
2466
3117
|
|
|
3118
|
+
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
|
|
3119
|
+
return ctx->set_sampler(seq_id, smpl);
|
|
3120
|
+
}
|
|
3121
|
+
|
|
3122
|
+
llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
|
|
3123
|
+
ctx->synchronize();
|
|
3124
|
+
|
|
3125
|
+
return ctx->get_sampled_token_ith(i);
|
|
3126
|
+
}
|
|
3127
|
+
|
|
3128
|
+
float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
|
|
3129
|
+
ctx->synchronize();
|
|
3130
|
+
|
|
3131
|
+
return ctx->get_sampled_probs_ith(i);
|
|
3132
|
+
}
|
|
3133
|
+
|
|
3134
|
+
float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
|
|
3135
|
+
ctx->synchronize();
|
|
3136
|
+
|
|
3137
|
+
return ctx->get_sampled_logits_ith(i);
|
|
3138
|
+
}
|
|
3139
|
+
|
|
3140
|
+
llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
|
|
3141
|
+
ctx->synchronize();
|
|
3142
|
+
|
|
3143
|
+
return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
|
|
3144
|
+
}
|
|
3145
|
+
|
|
3146
|
+
uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
|
|
3147
|
+
ctx->synchronize();
|
|
3148
|
+
|
|
3149
|
+
return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
|
|
3150
|
+
}
|
|
3151
|
+
|
|
3152
|
+
uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
|
|
3153
|
+
ctx->synchronize();
|
|
3154
|
+
|
|
3155
|
+
return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
|
|
3156
|
+
}
|
|
3157
|
+
|
|
3158
|
+
uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
|
3159
|
+
ctx->synchronize();
|
|
3160
|
+
|
|
3161
|
+
return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
|
|
3162
|
+
}
|
|
3163
|
+
|
|
2467
3164
|
// llama adapter API
|
|
2468
3165
|
|
|
2469
3166
|
int32_t llama_set_adapter_lora(
|