whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
#include "llama-context.h"
|
|
2
2
|
|
|
3
|
+
#include "llama-arch.h"
|
|
3
4
|
#include "llama-impl.h"
|
|
4
5
|
#include "llama-batch.h"
|
|
5
6
|
#include "llama-io.h"
|
|
6
7
|
#include "llama-memory.h"
|
|
7
8
|
#include "llama-mmap.h"
|
|
8
9
|
#include "llama-model.h"
|
|
10
|
+
#include "llama-ext.h"
|
|
9
11
|
|
|
10
12
|
#include <cinttypes>
|
|
13
|
+
#include <cmath>
|
|
11
14
|
#include <cstring>
|
|
12
15
|
#include <limits>
|
|
13
16
|
#include <stdexcept>
|
|
@@ -20,7 +23,11 @@ llama_context::llama_context(
|
|
|
20
23
|
const llama_model & model,
|
|
21
24
|
llama_context_params params) :
|
|
22
25
|
model(model),
|
|
26
|
+
cvec(std::make_unique<llama_adapter_cvec>()),
|
|
27
|
+
loras(std::make_unique<llama_adapter_loras>()),
|
|
23
28
|
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
|
29
|
+
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
|
30
|
+
// may need to be backend-dependent
|
|
24
31
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
|
25
32
|
|
|
26
33
|
t_start_us = model.t_start_us;
|
|
@@ -56,6 +63,25 @@ llama_context::llama_context(
|
|
|
56
63
|
cparams.cb_eval = params.cb_eval;
|
|
57
64
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
|
58
65
|
|
|
66
|
+
// Initialize backend samplers here so they are part of the sampling graph
|
|
67
|
+
// before the reserve passes run later in this function. This avoids a later
|
|
68
|
+
// re-reserve when graph nodes change.
|
|
69
|
+
if (params.samplers != nullptr && params.n_samplers > 0) {
|
|
70
|
+
for (size_t i = 0; i < params.n_samplers; ++i) {
|
|
71
|
+
const auto & config = params.samplers[i];
|
|
72
|
+
|
|
73
|
+
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
|
|
74
|
+
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
if (set_sampler(config.seq_id, config.sampler)) {
|
|
78
|
+
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
|
79
|
+
|
|
80
|
+
LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
59
85
|
auto rope_scaling_type = params.rope_scaling_type;
|
|
60
86
|
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
|
61
87
|
rope_scaling_type = hparams.rope_scaling_type_train;
|
|
@@ -69,6 +95,43 @@ llama_context::llama_context(
|
|
|
69
95
|
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
|
70
96
|
}
|
|
71
97
|
|
|
98
|
+
if (cparams.yarn_ext_factor != 0) {
|
|
99
|
+
static auto get_mscale = [](float scale, float mscale) {
|
|
100
|
+
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
|
101
|
+
};
|
|
102
|
+
|
|
103
|
+
const float factor = 1.0f / cparams.rope_freq_scale;
|
|
104
|
+
|
|
105
|
+
// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
|
|
106
|
+
if (hparams.rope_yarn_log_mul != 0.0f) {
|
|
107
|
+
// note: here we assume `mscale == 1.0f`
|
|
108
|
+
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
|
|
109
|
+
float mscale = 1.0f;
|
|
110
|
+
const float mscale_all_dims = hparams.rope_yarn_log_mul;
|
|
111
|
+
|
|
112
|
+
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
|
113
|
+
// special-case DEEPSEEK v2:
|
|
114
|
+
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
|
|
115
|
+
if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
|
|
116
|
+
mscale = mscale_all_dims;
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
|
120
|
+
|
|
121
|
+
LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
|
|
122
|
+
__func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
|
|
123
|
+
} else {
|
|
124
|
+
cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
|
|
128
|
+
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
|
|
129
|
+
//
|
|
130
|
+
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
|
|
131
|
+
// https://github.com/ggml-org/llama.cpp/pull/17945
|
|
132
|
+
cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
|
|
133
|
+
}
|
|
134
|
+
|
|
72
135
|
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
|
73
136
|
|
|
74
137
|
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
|
@@ -86,23 +149,23 @@ llama_context::llama_context(
|
|
|
86
149
|
}
|
|
87
150
|
|
|
88
151
|
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
|
152
|
+
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
|
|
153
|
+
|
|
154
|
+
cparams.fused_gdn_ar = true;
|
|
155
|
+
cparams.fused_gdn_ch = true;
|
|
156
|
+
cparams.auto_fgdn = true;
|
|
89
157
|
|
|
90
158
|
// with causal attention, the batch size is limited by the context size
|
|
91
159
|
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
|
92
160
|
|
|
93
|
-
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
|
94
|
-
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
|
95
|
-
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
|
96
|
-
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
|
97
|
-
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
|
98
|
-
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
|
99
|
-
cparams.n_batch = GGML_KQ_MASK_PAD;
|
|
100
|
-
}
|
|
101
161
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
|
102
162
|
|
|
103
163
|
cparams.op_offload = params.op_offload;
|
|
104
164
|
cparams.kv_unified = params.kv_unified;
|
|
105
165
|
|
|
166
|
+
// initialized later
|
|
167
|
+
cparams.pipeline_parallel = false;
|
|
168
|
+
|
|
106
169
|
{
|
|
107
170
|
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
|
|
108
171
|
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
|
|
@@ -112,11 +175,28 @@ llama_context::llama_context(
|
|
|
112
175
|
}
|
|
113
176
|
}
|
|
114
177
|
|
|
115
|
-
|
|
178
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
|
|
179
|
+
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
|
|
180
|
+
|
|
181
|
+
if (cparams.kv_unified) {
|
|
182
|
+
cparams.n_ctx_seq = cparams.n_ctx;
|
|
183
|
+
} else {
|
|
184
|
+
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
185
|
+
cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
|
|
186
|
+
|
|
187
|
+
if (cparams.n_ctx_seq == 0) {
|
|
188
|
+
throw std::runtime_error("n_ctx_seq == 0");
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
|
|
192
|
+
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
|
|
193
|
+
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
|
|
194
|
+
}
|
|
195
|
+
}
|
|
116
196
|
|
|
117
197
|
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
|
118
198
|
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
|
119
|
-
LLAMA_LOG_INFO("%s:
|
|
199
|
+
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
|
|
120
200
|
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
|
121
201
|
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
|
122
202
|
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
|
@@ -125,14 +205,14 @@ llama_context::llama_context(
|
|
|
125
205
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
|
126
206
|
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
|
127
207
|
|
|
128
|
-
if (
|
|
129
|
-
LLAMA_LOG_WARN("%s:
|
|
130
|
-
__func__,
|
|
208
|
+
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
|
209
|
+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
|
210
|
+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
|
131
211
|
}
|
|
132
212
|
|
|
133
|
-
if (
|
|
134
|
-
LLAMA_LOG_WARN("%s:
|
|
135
|
-
__func__,
|
|
213
|
+
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
|
|
214
|
+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
|
215
|
+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
|
136
216
|
}
|
|
137
217
|
|
|
138
218
|
if (!hparams.vocab_only) {
|
|
@@ -180,7 +260,6 @@ llama_context::llama_context(
|
|
|
180
260
|
|
|
181
261
|
// graph outputs buffer
|
|
182
262
|
{
|
|
183
|
-
// resized during inference when a batch uses more outputs
|
|
184
263
|
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
|
|
185
264
|
throw std::runtime_error("failed to reserve initial output buffer");
|
|
186
265
|
}
|
|
@@ -208,6 +287,7 @@ llama_context::llama_context(
|
|
|
208
287
|
|
|
209
288
|
backend_buft.clear();
|
|
210
289
|
backend_ptrs.clear();
|
|
290
|
+
backend_buf_exp_size.clear();
|
|
211
291
|
|
|
212
292
|
for (auto & backend : backends) {
|
|
213
293
|
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
|
@@ -224,23 +304,17 @@ llama_context::llama_context(
|
|
|
224
304
|
|
|
225
305
|
backend_buft.push_back(buft);
|
|
226
306
|
backend_ptrs.push_back(backend.get());
|
|
307
|
+
backend_buf_exp_size.push_back(0);
|
|
227
308
|
}
|
|
228
309
|
|
|
229
310
|
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
|
230
311
|
|
|
231
|
-
const size_t max_nodes = this->graph_max_nodes();
|
|
232
|
-
|
|
233
|
-
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
234
|
-
|
|
235
|
-
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
|
236
|
-
gf_res_reserve.reset(new llm_graph_result(max_nodes));
|
|
237
|
-
|
|
238
312
|
// TODO: move these checks to ggml_backend_sched
|
|
239
313
|
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
240
314
|
bool pipeline_parallel =
|
|
241
315
|
model.n_devices() > 1 &&
|
|
242
|
-
model.
|
|
243
|
-
model.
|
|
316
|
+
model.n_gpu_layers() > model.hparams.n_layer &&
|
|
317
|
+
model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
|
|
244
318
|
cparams.offload_kqv &&
|
|
245
319
|
!model.has_tensor_overrides();
|
|
246
320
|
|
|
@@ -250,6 +324,7 @@ llama_context::llama_context(
|
|
|
250
324
|
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
251
325
|
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
|
252
326
|
// ignore CPU backend
|
|
327
|
+
// TODO: should we ignore ACCEL types too?
|
|
253
328
|
continue;
|
|
254
329
|
}
|
|
255
330
|
auto * dev = ggml_backend_get_device(backend.get());
|
|
@@ -263,146 +338,308 @@ llama_context::llama_context(
|
|
|
263
338
|
}
|
|
264
339
|
}
|
|
265
340
|
|
|
266
|
-
|
|
341
|
+
cparams.pipeline_parallel = pipeline_parallel;
|
|
267
342
|
|
|
268
|
-
if (pipeline_parallel) {
|
|
269
|
-
LLAMA_LOG_INFO("%s: pipeline parallelism enabled
|
|
343
|
+
if (cparams.pipeline_parallel) {
|
|
344
|
+
LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
|
|
345
|
+
|
|
346
|
+
if (!graph_reuse_disable) {
|
|
347
|
+
// TODO: figure out a way to make graph reuse work with pipeline parallelism
|
|
348
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/20463
|
|
349
|
+
LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__);
|
|
350
|
+
|
|
351
|
+
graph_reuse_disable = true;
|
|
352
|
+
}
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
sched_reserve();
|
|
356
|
+
|
|
357
|
+
if (!cparams.flash_attn) {
|
|
358
|
+
if (ggml_is_quantized(params.type_v)) {
|
|
359
|
+
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
360
|
+
}
|
|
270
361
|
}
|
|
271
362
|
}
|
|
272
363
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
364
|
+
// Initialize the full vocabulary token ids for backend samplers.
|
|
365
|
+
{
|
|
366
|
+
const int n_vocab = model.vocab.n_tokens();
|
|
367
|
+
|
|
368
|
+
sampling.token_ids_full_vocab.resize(n_vocab);
|
|
369
|
+
for (int i = 0; i < n_vocab; ++i) {
|
|
370
|
+
sampling.token_ids_full_vocab[i] = i;
|
|
371
|
+
}
|
|
372
|
+
}
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
llama_context::~llama_context() {
|
|
376
|
+
if (!model.hparams.no_alloc) {
|
|
377
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
378
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
379
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
380
|
+
|
|
381
|
+
const size_t size_exp = backend_buf_exp_size[i];
|
|
382
|
+
const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
383
|
+
if (size_exp == size_act) {
|
|
384
|
+
LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
|
|
385
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
386
|
+
} else {
|
|
387
|
+
LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
|
388
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
280
389
|
}
|
|
281
390
|
}
|
|
391
|
+
}
|
|
392
|
+
ggml_opt_free(opt_ctx);
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
void llama_context::sched_reserve() {
|
|
396
|
+
if (!sched_need_reserve) {
|
|
397
|
+
return;
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
sched_need_reserve = false;
|
|
401
|
+
|
|
402
|
+
LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
|
|
403
|
+
|
|
404
|
+
synchronize();
|
|
405
|
+
|
|
406
|
+
const int64_t t_start_us = ggml_time_us();
|
|
407
|
+
|
|
408
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
409
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
410
|
+
|
|
411
|
+
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
|
412
|
+
|
|
413
|
+
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
414
|
+
|
|
415
|
+
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
|
416
|
+
gf_res_reserve.reset(new llm_graph_result(max_nodes));
|
|
417
|
+
|
|
418
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
|
|
419
|
+
|
|
420
|
+
llama_memory_context_ptr mctx;
|
|
421
|
+
if (memory) {
|
|
422
|
+
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
|
423
|
+
mctx = memory->init_full();
|
|
424
|
+
if (!mctx) {
|
|
425
|
+
throw std::runtime_error("failed to initialize memory module");
|
|
426
|
+
}
|
|
427
|
+
}
|
|
282
428
|
|
|
283
|
-
|
|
429
|
+
// avoid reserving graphs with zero outputs - assume one output per sequence
|
|
430
|
+
const int n_outputs = n_seqs;
|
|
284
431
|
|
|
285
|
-
|
|
286
|
-
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
432
|
+
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
287
433
|
|
|
288
|
-
|
|
289
|
-
|
|
434
|
+
// resolve automatic Flash Attention use
|
|
435
|
+
if (cparams.auto_fa) {
|
|
436
|
+
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
|
437
|
+
if (!gf) {
|
|
438
|
+
throw std::runtime_error("failed to reserve graph for Flash Attention check");
|
|
439
|
+
}
|
|
290
440
|
|
|
291
|
-
|
|
441
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
|
|
442
|
+
bool fa_device_mismatch = false;
|
|
443
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
444
|
+
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
445
|
+
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
|
|
446
|
+
continue;
|
|
447
|
+
}
|
|
448
|
+
ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
449
|
+
|
|
450
|
+
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
|
|
451
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
452
|
+
const int il = std::stoi(n->name + prefix_len);
|
|
453
|
+
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
454
|
+
if (device_fa != device_kv) {
|
|
455
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
|
|
456
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
457
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
|
|
458
|
+
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
|
|
459
|
+
fa_device_mismatch = true;
|
|
460
|
+
break;
|
|
461
|
+
}
|
|
462
|
+
}
|
|
292
463
|
|
|
293
|
-
|
|
294
|
-
|
|
464
|
+
if (fa_device_mismatch) {
|
|
465
|
+
cparams.flash_attn = false;
|
|
466
|
+
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
|
|
467
|
+
} else {
|
|
468
|
+
cparams.flash_attn = true;
|
|
469
|
+
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
cparams.auto_fa = false;
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
if (cparams.auto_fgdn) {
|
|
476
|
+
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
|
|
477
|
+
|
|
478
|
+
if (cparams.fused_gdn_ar) {
|
|
295
479
|
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
|
296
480
|
if (!gf) {
|
|
297
|
-
throw std::runtime_error("failed to
|
|
481
|
+
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
|
|
298
482
|
}
|
|
299
483
|
|
|
300
|
-
const size_t prefix_len = strlen(
|
|
301
|
-
bool
|
|
484
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
|
|
485
|
+
bool gdn_device_mismatch = false;
|
|
302
486
|
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
303
487
|
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
304
|
-
if (n->op !=
|
|
488
|
+
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
|
305
489
|
continue;
|
|
306
490
|
}
|
|
307
|
-
ggml_backend_dev_t
|
|
308
|
-
ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
491
|
+
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
309
492
|
|
|
310
|
-
|
|
311
|
-
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
493
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
|
|
312
494
|
const int il = std::stoi(n->name + prefix_len);
|
|
313
495
|
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
314
|
-
if (
|
|
315
|
-
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
fa_device_mismatch = true;
|
|
496
|
+
if (device_gdn != device_kv) {
|
|
497
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
|
498
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
499
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
|
500
|
+
gdn_device_mismatch = true;
|
|
320
501
|
break;
|
|
321
502
|
}
|
|
322
503
|
}
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
328
|
-
}
|
|
504
|
+
|
|
505
|
+
if (gdn_device_mismatch) {
|
|
506
|
+
cparams.fused_gdn_ar = false;
|
|
507
|
+
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
|
|
329
508
|
} else {
|
|
330
|
-
|
|
331
|
-
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
509
|
+
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
|
|
332
510
|
}
|
|
333
511
|
}
|
|
334
512
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
{
|
|
344
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
513
|
+
if (cparams.fused_gdn_ch) {
|
|
514
|
+
// more than one token in the batch per sequence in order to take the chunked path
|
|
515
|
+
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
|
|
516
|
+
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
|
|
517
|
+
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
|
|
518
|
+
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
|
|
519
|
+
const uint32_t n_tokens_ch = 16*n_seqs;
|
|
520
|
+
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
|
|
345
521
|
if (!gf) {
|
|
346
|
-
throw std::runtime_error("failed to
|
|
522
|
+
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
|
|
347
523
|
}
|
|
348
524
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
525
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
|
|
526
|
+
bool gdn_device_mismatch = false;
|
|
527
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
528
|
+
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
529
|
+
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
|
530
|
+
continue;
|
|
531
|
+
}
|
|
532
|
+
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
352
533
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
534
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
|
|
535
|
+
const int il = std::stoi(n->name + prefix_len);
|
|
536
|
+
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
537
|
+
if (device_gdn != device_kv) {
|
|
538
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
|
539
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
540
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
|
541
|
+
gdn_device_mismatch = true;
|
|
542
|
+
break;
|
|
543
|
+
}
|
|
358
544
|
}
|
|
359
545
|
|
|
360
|
-
|
|
361
|
-
|
|
546
|
+
if (gdn_device_mismatch) {
|
|
547
|
+
cparams.fused_gdn_ch = false;
|
|
548
|
+
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
|
|
549
|
+
} else {
|
|
550
|
+
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
|
|
551
|
+
}
|
|
362
552
|
}
|
|
363
553
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
554
|
+
cparams.auto_fgdn = false;
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
// reserve worst-case graph
|
|
558
|
+
int n_splits_pp = -1;
|
|
559
|
+
int n_nodes_pp = -1;
|
|
560
|
+
|
|
561
|
+
int n_splits_tg = -1;
|
|
562
|
+
int n_nodes_tg = -1;
|
|
563
|
+
|
|
564
|
+
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
|
565
|
+
{
|
|
566
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
|
567
|
+
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
|
568
|
+
if (!gf) {
|
|
569
|
+
if (cparams.pipeline_parallel) {
|
|
570
|
+
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
|
571
|
+
cparams.pipeline_parallel = false;
|
|
572
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
|
|
573
|
+
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
574
|
+
}
|
|
371
575
|
if (!gf) {
|
|
372
576
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
373
577
|
}
|
|
374
578
|
}
|
|
375
579
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
580
|
+
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
581
|
+
n_nodes_pp = ggml_graph_n_nodes(gf);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
// reserve with tg (token generation) graph to get the number of splits and nodes
|
|
585
|
+
{
|
|
586
|
+
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
|
587
|
+
if (!gf) {
|
|
588
|
+
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
385
589
|
}
|
|
386
590
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
591
|
+
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
|
592
|
+
n_nodes_tg = ggml_graph_n_nodes(gf);
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
596
|
+
{
|
|
597
|
+
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
|
|
598
|
+
//
|
|
599
|
+
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
|
600
|
+
//
|
|
601
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
|
|
602
|
+
if (!gf) {
|
|
603
|
+
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
391
604
|
}
|
|
605
|
+
}
|
|
392
606
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
607
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
608
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
609
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
610
|
+
if (!model.hparams.no_alloc) {
|
|
611
|
+
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
612
|
+
}
|
|
613
|
+
if (backend_buf_exp_size[i] > 1) {
|
|
614
|
+
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
615
|
+
ggml_backend_buft_name(buft),
|
|
616
|
+
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
|
397
617
|
}
|
|
398
618
|
}
|
|
399
|
-
}
|
|
400
619
|
|
|
401
|
-
|
|
402
|
-
|
|
620
|
+
if (n_nodes_pp == n_nodes_tg) {
|
|
621
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
|
622
|
+
} else {
|
|
623
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
if (n_splits_pp == n_splits_tg) {
|
|
627
|
+
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
|
628
|
+
} else {
|
|
629
|
+
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
const int64_t t_end_us = ggml_time_us();
|
|
633
|
+
|
|
634
|
+
LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n",
|
|
635
|
+
__func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get()));
|
|
403
636
|
}
|
|
404
637
|
|
|
405
638
|
void llama_context::synchronize() {
|
|
639
|
+
if (!sched) {
|
|
640
|
+
return;
|
|
641
|
+
}
|
|
642
|
+
|
|
406
643
|
ggml_backend_sched_synchronize(sched.get());
|
|
407
644
|
|
|
408
645
|
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
|
@@ -448,8 +685,8 @@ uint32_t llama_context::n_ctx() const {
|
|
|
448
685
|
return cparams.n_ctx;
|
|
449
686
|
}
|
|
450
687
|
|
|
451
|
-
uint32_t llama_context::
|
|
452
|
-
return cparams.
|
|
688
|
+
uint32_t llama_context::n_ctx_seq() const {
|
|
689
|
+
return cparams.n_ctx_seq;
|
|
453
690
|
}
|
|
454
691
|
|
|
455
692
|
uint32_t llama_context::n_batch() const {
|
|
@@ -518,7 +755,7 @@ bool llama_context::memory_update(bool optimize) {
|
|
|
518
755
|
throw std::runtime_error("failed to initialize memory context");
|
|
519
756
|
}
|
|
520
757
|
|
|
521
|
-
const uint32_t n_seqs = cparams.
|
|
758
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
522
759
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
523
760
|
|
|
524
761
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
@@ -537,39 +774,48 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
|
|
537
774
|
float * llama_context::get_logits() {
|
|
538
775
|
output_reorder();
|
|
539
776
|
|
|
540
|
-
return logits;
|
|
777
|
+
return logits.data;
|
|
541
778
|
}
|
|
542
779
|
|
|
543
|
-
|
|
780
|
+
int64_t llama_context::output_resolve_row(int32_t i) const {
|
|
544
781
|
int64_t j = -1;
|
|
545
782
|
|
|
783
|
+
// support negative indices (last output row)
|
|
784
|
+
if (i < 0) {
|
|
785
|
+
j = n_outputs + i;
|
|
786
|
+
if (j < 0) {
|
|
787
|
+
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
788
|
+
}
|
|
789
|
+
} else if ((size_t) i >= output_ids.size()) {
|
|
790
|
+
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
791
|
+
} else {
|
|
792
|
+
// use output_ids to translate the batch token index into a row number
|
|
793
|
+
// that holds this token's data.
|
|
794
|
+
j = output_ids[i];
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
if (j < 0) {
|
|
798
|
+
// the batch token was not configured to output anything
|
|
799
|
+
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
if (j >= n_outputs) {
|
|
803
|
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
804
|
+
}
|
|
805
|
+
|
|
806
|
+
return j;
|
|
807
|
+
}
|
|
808
|
+
|
|
809
|
+
float * llama_context::get_logits_ith(int32_t i) {
|
|
546
810
|
output_reorder();
|
|
547
811
|
|
|
548
812
|
try {
|
|
549
|
-
if (logits == nullptr) {
|
|
813
|
+
if (logits.data == nullptr) {
|
|
550
814
|
throw std::runtime_error("no logits");
|
|
551
815
|
}
|
|
552
816
|
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
if (j < 0) {
|
|
556
|
-
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
557
|
-
}
|
|
558
|
-
} else if ((size_t) i >= output_ids.size()) {
|
|
559
|
-
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
560
|
-
} else {
|
|
561
|
-
j = output_ids[i];
|
|
562
|
-
}
|
|
563
|
-
|
|
564
|
-
if (j < 0) {
|
|
565
|
-
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
566
|
-
}
|
|
567
|
-
if (j >= n_outputs) {
|
|
568
|
-
// This should not happen
|
|
569
|
-
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
570
|
-
}
|
|
571
|
-
|
|
572
|
-
return logits + j*model.vocab.n_tokens();
|
|
817
|
+
const int64_t j = output_resolve_row(i);
|
|
818
|
+
return logits.data + j*model.vocab.n_tokens();
|
|
573
819
|
} catch (const std::exception & err) {
|
|
574
820
|
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
|
575
821
|
#ifndef NDEBUG
|
|
@@ -583,39 +829,24 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
583
829
|
float * llama_context::get_embeddings() {
|
|
584
830
|
output_reorder();
|
|
585
831
|
|
|
586
|
-
return embd;
|
|
832
|
+
return embd.data;
|
|
587
833
|
}
|
|
588
834
|
|
|
589
|
-
|
|
590
|
-
|
|
835
|
+
llama_token * llama_context::get_sampled_tokens() const{
|
|
836
|
+
return sampling.sampled.data;
|
|
837
|
+
}
|
|
591
838
|
|
|
839
|
+
float * llama_context::get_embeddings_ith(int32_t i) {
|
|
592
840
|
output_reorder();
|
|
593
841
|
|
|
594
842
|
try {
|
|
595
|
-
if (embd == nullptr) {
|
|
843
|
+
if (embd.data == nullptr) {
|
|
596
844
|
throw std::runtime_error("no embeddings");
|
|
597
845
|
}
|
|
598
846
|
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
603
|
-
}
|
|
604
|
-
} else if ((size_t) i >= output_ids.size()) {
|
|
605
|
-
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
606
|
-
} else {
|
|
607
|
-
j = output_ids[i];
|
|
608
|
-
}
|
|
609
|
-
|
|
610
|
-
if (j < 0) {
|
|
611
|
-
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
612
|
-
}
|
|
613
|
-
if (j >= n_outputs) {
|
|
614
|
-
// This should not happen
|
|
615
|
-
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
616
|
-
}
|
|
617
|
-
|
|
618
|
-
return embd + j*model.hparams.n_embd;
|
|
847
|
+
const int64_t j = output_resolve_row(i);
|
|
848
|
+
const uint32_t n_embd_out = model.hparams.n_embd_out();
|
|
849
|
+
return embd.data + j*n_embd_out;
|
|
619
850
|
} catch (const std::exception & err) {
|
|
620
851
|
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
|
621
852
|
#ifndef NDEBUG
|
|
@@ -635,6 +866,137 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
635
866
|
return it->second.data();
|
|
636
867
|
}
|
|
637
868
|
|
|
869
|
+
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
|
870
|
+
output_reorder();
|
|
871
|
+
|
|
872
|
+
if (!sampling.sampled.has_data()) {
|
|
873
|
+
return LLAMA_TOKEN_NULL;
|
|
874
|
+
}
|
|
875
|
+
|
|
876
|
+
try {
|
|
877
|
+
const int64_t row = output_resolve_row(idx);
|
|
878
|
+
GGML_ASSERT(row < (int64_t) sampling.sampled.size);
|
|
879
|
+
return sampling.sampled.data[row];
|
|
880
|
+
} catch (const std::exception & err) {
|
|
881
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
|
|
882
|
+
return LLAMA_TOKEN_NULL;
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
|
|
886
|
+
float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
|
887
|
+
output_reorder();
|
|
888
|
+
|
|
889
|
+
if (!sampling.probs.has_data()) {
|
|
890
|
+
return nullptr;
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
try {
|
|
894
|
+
const int64_t row = output_resolve_row(idx);
|
|
895
|
+
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
|
|
896
|
+
return nullptr;
|
|
897
|
+
}
|
|
898
|
+
return sampling.probs.data + row*model.vocab.n_tokens();
|
|
899
|
+
} catch (const std::exception & err) {
|
|
900
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
|
|
901
|
+
return nullptr;
|
|
902
|
+
}
|
|
903
|
+
}
|
|
904
|
+
|
|
905
|
+
float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
|
906
|
+
output_reorder();
|
|
907
|
+
|
|
908
|
+
if (!sampling.logits.has_data()) {
|
|
909
|
+
return nullptr;
|
|
910
|
+
}
|
|
911
|
+
|
|
912
|
+
try {
|
|
913
|
+
const int64_t row = output_resolve_row(idx);
|
|
914
|
+
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
|
|
915
|
+
return nullptr;
|
|
916
|
+
}
|
|
917
|
+
return sampling.logits.data + row*model.vocab.n_tokens();
|
|
918
|
+
} catch (const std::exception & err) {
|
|
919
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
|
|
920
|
+
return nullptr;
|
|
921
|
+
}
|
|
922
|
+
}
|
|
923
|
+
|
|
924
|
+
const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
|
925
|
+
output_reorder();
|
|
926
|
+
|
|
927
|
+
try {
|
|
928
|
+
const int64_t row = output_resolve_row(idx);
|
|
929
|
+
if (sampling.candidates.has_data() &&
|
|
930
|
+
(size_t) row < sampling.candidates_count.size() &&
|
|
931
|
+
sampling.candidates_count[row] > 0) {
|
|
932
|
+
return sampling.candidates.data + row*model.vocab.n_tokens();
|
|
933
|
+
}
|
|
934
|
+
} catch (const std::exception & err) {
|
|
935
|
+
// fallback to full vocab list
|
|
936
|
+
GGML_UNUSED(err);
|
|
937
|
+
}
|
|
938
|
+
|
|
939
|
+
return sampling.token_ids_full_vocab.data();
|
|
940
|
+
}
|
|
941
|
+
|
|
942
|
+
size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
|
943
|
+
output_reorder();
|
|
944
|
+
|
|
945
|
+
if (!sampling.candidates.has_data()) {
|
|
946
|
+
return 0;
|
|
947
|
+
}
|
|
948
|
+
|
|
949
|
+
try {
|
|
950
|
+
const int64_t row = output_resolve_row(idx);
|
|
951
|
+
if ((size_t) row >= sampling.candidates_count.size()) {
|
|
952
|
+
return 0;
|
|
953
|
+
}
|
|
954
|
+
return sampling.candidates_count[row];
|
|
955
|
+
} catch (const std::exception & err) {
|
|
956
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
|
|
957
|
+
return 0;
|
|
958
|
+
}
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
|
962
|
+
output_reorder();
|
|
963
|
+
|
|
964
|
+
if (!sampling.logits.has_data()) {
|
|
965
|
+
return model.vocab.n_tokens();
|
|
966
|
+
}
|
|
967
|
+
|
|
968
|
+
try {
|
|
969
|
+
const int64_t row = output_resolve_row(idx);
|
|
970
|
+
if ((size_t) row >= sampling.logits_count.size()) {
|
|
971
|
+
return 0;
|
|
972
|
+
}
|
|
973
|
+
return sampling.logits_count[row];
|
|
974
|
+
} catch (const std::exception & err) {
|
|
975
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
|
|
976
|
+
return 0;
|
|
977
|
+
}
|
|
978
|
+
}
|
|
979
|
+
|
|
980
|
+
size_t llama_context::get_sampled_probs_count(int32_t idx) {
|
|
981
|
+
output_reorder();
|
|
982
|
+
|
|
983
|
+
if (!sampling.probs.has_data()) {
|
|
984
|
+
return 0;
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
try {
|
|
988
|
+
const int64_t row = output_resolve_row(idx);
|
|
989
|
+
if ((size_t) row >= sampling.probs_count.size()) {
|
|
990
|
+
return 0;
|
|
991
|
+
}
|
|
992
|
+
return sampling.probs_count[row];
|
|
993
|
+
} catch (const std::exception & err) {
|
|
994
|
+
LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
|
|
995
|
+
return 0;
|
|
996
|
+
}
|
|
997
|
+
}
|
|
998
|
+
|
|
999
|
+
|
|
638
1000
|
void llama_context::attach_threadpool(
|
|
639
1001
|
ggml_threadpool_t threadpool,
|
|
640
1002
|
ggml_threadpool_t threadpool_batch) {
|
|
@@ -671,54 +1033,131 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
|
|
|
671
1033
|
set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
|
|
672
1034
|
}
|
|
673
1035
|
}
|
|
674
|
-
}
|
|
1036
|
+
}
|
|
1037
|
+
|
|
1038
|
+
void llama_context::set_embeddings(bool value) {
|
|
1039
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
1040
|
+
|
|
1041
|
+
cparams.embeddings = value;
|
|
1042
|
+
|
|
1043
|
+
// TODO: not sure yet if we want to reserve here
|
|
1044
|
+
//sched_need_reserve = true;
|
|
1045
|
+
}
|
|
1046
|
+
|
|
1047
|
+
void llama_context::set_causal_attn(bool value) {
|
|
1048
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
1049
|
+
|
|
1050
|
+
if (cparams.causal_attn == value) {
|
|
1051
|
+
return;
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
cparams.causal_attn = value;
|
|
1055
|
+
|
|
1056
|
+
sched_need_reserve = true;
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
void llama_context::set_warmup(bool value) {
|
|
1060
|
+
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
1061
|
+
|
|
1062
|
+
if (cparams.warmup == value) {
|
|
1063
|
+
return;
|
|
1064
|
+
}
|
|
1065
|
+
|
|
1066
|
+
cparams.warmup = value;
|
|
1067
|
+
|
|
1068
|
+
// warmups are usually with small batches, so no need to reserve
|
|
1069
|
+
//sched_need_reserve = true;
|
|
1070
|
+
}
|
|
1071
|
+
|
|
1072
|
+
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|
1073
|
+
if (!sampler && sampling.samplers.count(seq_id) == 0) {
|
|
1074
|
+
return true;
|
|
1075
|
+
}
|
|
1076
|
+
|
|
1077
|
+
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
|
1078
|
+
|
|
1079
|
+
const bool can_offload =
|
|
1080
|
+
sampler &&
|
|
1081
|
+
sampler->iface->backend_init &&
|
|
1082
|
+
sampler->iface->backend_apply &&
|
|
1083
|
+
llama_sampler_chain_n(sampler) > 0;
|
|
1084
|
+
|
|
1085
|
+
if (sampler && can_offload) {
|
|
1086
|
+
auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
|
|
1087
|
+
|
|
1088
|
+
sampler->iface->backend_init(sampler, buft);
|
|
1089
|
+
|
|
1090
|
+
sampling.samplers[seq_id] = sampler;
|
|
1091
|
+
|
|
1092
|
+
sched_need_reserve = true;
|
|
1093
|
+
|
|
1094
|
+
return true;
|
|
1095
|
+
}
|
|
675
1096
|
|
|
676
|
-
|
|
677
|
-
|
|
1097
|
+
if (sampler && !can_offload) {
|
|
1098
|
+
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
|
|
678
1099
|
|
|
679
|
-
|
|
680
|
-
|
|
1100
|
+
if (sampling.samplers.count(seq_id) > 0) {
|
|
1101
|
+
sched_need_reserve = true;
|
|
1102
|
+
}
|
|
681
1103
|
|
|
682
|
-
|
|
683
|
-
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
1104
|
+
sampling.samplers.erase(seq_id);
|
|
684
1105
|
|
|
685
|
-
|
|
686
|
-
}
|
|
1106
|
+
return false;
|
|
1107
|
+
}
|
|
687
1108
|
|
|
688
|
-
|
|
689
|
-
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
1109
|
+
sampling.samplers.erase(seq_id);
|
|
690
1110
|
|
|
691
|
-
|
|
1111
|
+
sched_need_reserve = true;
|
|
1112
|
+
|
|
1113
|
+
return true;
|
|
692
1114
|
}
|
|
693
1115
|
|
|
694
|
-
void llama_context::
|
|
695
|
-
|
|
696
|
-
float scale) {
|
|
697
|
-
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
|
|
1116
|
+
void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
|
1117
|
+
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
|
698
1118
|
|
|
699
|
-
|
|
700
|
-
|
|
1119
|
+
if (adapters_lora_are_same(adapters, n_adapters, scales)) {
|
|
1120
|
+
return;
|
|
1121
|
+
}
|
|
701
1122
|
|
|
702
|
-
|
|
703
|
-
llama_adapter_lora * adapter) {
|
|
704
|
-
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
|
|
1123
|
+
loras.reset(new llama_adapter_loras());
|
|
705
1124
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
1125
|
+
for (size_t i = 0; i < n_adapters; i ++) {
|
|
1126
|
+
if (scales[i] != 0.0f) {
|
|
1127
|
+
loras->insert({adapters[i], scales[i]});
|
|
1128
|
+
}
|
|
710
1129
|
}
|
|
711
1130
|
|
|
712
|
-
|
|
1131
|
+
sched_need_reserve = true;
|
|
713
1132
|
}
|
|
714
1133
|
|
|
715
|
-
|
|
716
|
-
LLAMA_LOG_DEBUG("%s:
|
|
1134
|
+
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
|
1135
|
+
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
|
1136
|
+
|
|
1137
|
+
// Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison.
|
|
1138
|
+
size_t n_non_zero = 0;
|
|
1139
|
+
|
|
1140
|
+
for (size_t i = 0; i < n_adapters; i ++) {
|
|
1141
|
+
if (scales[i] == 0.0f) {
|
|
1142
|
+
continue;
|
|
1143
|
+
}
|
|
1144
|
+
n_non_zero++;
|
|
1145
|
+
|
|
1146
|
+
auto it = loras->find(adapters[i]);
|
|
1147
|
+
|
|
1148
|
+
if (it == loras->end() || it->second != scales[i]) {
|
|
1149
|
+
return false;
|
|
1150
|
+
}
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
if (n_non_zero != loras->size()) {
|
|
1154
|
+
return false;
|
|
1155
|
+
}
|
|
717
1156
|
|
|
718
|
-
|
|
1157
|
+
return true;
|
|
719
1158
|
}
|
|
720
1159
|
|
|
721
|
-
bool llama_context::
|
|
1160
|
+
bool llama_context::set_adapter_cvec(
|
|
722
1161
|
const float * data,
|
|
723
1162
|
size_t len,
|
|
724
1163
|
int32_t n_embd,
|
|
@@ -726,7 +1165,9 @@ bool llama_context::apply_adapter_cvec(
|
|
|
726
1165
|
int32_t il_end) {
|
|
727
1166
|
LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
|
|
728
1167
|
|
|
729
|
-
|
|
1168
|
+
// TODO: should we reserve?
|
|
1169
|
+
|
|
1170
|
+
return cvec->apply(model, data, len, n_embd, il_start, il_end);
|
|
730
1171
|
}
|
|
731
1172
|
|
|
732
1173
|
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
|
@@ -776,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
|
|
776
1217
|
{
|
|
777
1218
|
//const auto t_start_us = ggml_time_us();
|
|
778
1219
|
|
|
1220
|
+
// FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated
|
|
779
1221
|
res->set_inputs(&ubatch);
|
|
780
1222
|
|
|
781
1223
|
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
|
|
@@ -803,7 +1245,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
803
1245
|
|
|
804
1246
|
const auto & hparams = model.hparams;
|
|
805
1247
|
|
|
806
|
-
const int64_t n_embd = hparams.
|
|
1248
|
+
const int64_t n_embd = hparams.n_embd_inp();
|
|
807
1249
|
const int64_t n_vocab = model.vocab.n_tokens();
|
|
808
1250
|
|
|
809
1251
|
// note: during encode, we always pass the full sequence starting from pos = 0
|
|
@@ -828,6 +1270,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
828
1270
|
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
829
1271
|
embd_seq.clear();
|
|
830
1272
|
|
|
1273
|
+
sched_reserve();
|
|
1274
|
+
|
|
831
1275
|
n_queued_tokens += n_tokens;
|
|
832
1276
|
|
|
833
1277
|
// reserve output buffer
|
|
@@ -867,16 +1311,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
867
1311
|
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
|
868
1312
|
|
|
869
1313
|
// extract logits
|
|
870
|
-
|
|
1314
|
+
if (logits.data && t_logits) {
|
|
871
1315
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
872
1316
|
GGML_ASSERT(backend_res != nullptr);
|
|
873
|
-
GGML_ASSERT(logits != nullptr);
|
|
1317
|
+
GGML_ASSERT(logits.data != nullptr);
|
|
874
1318
|
|
|
875
|
-
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
|
|
1319
|
+
ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
|
|
876
1320
|
}
|
|
877
1321
|
|
|
878
1322
|
// extract embeddings
|
|
879
|
-
if (embd && t_embd) {
|
|
1323
|
+
if (embd.data && t_embd) {
|
|
880
1324
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
881
1325
|
GGML_ASSERT(backend_embd != nullptr);
|
|
882
1326
|
|
|
@@ -884,10 +1328,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
884
1328
|
case LLAMA_POOLING_TYPE_NONE:
|
|
885
1329
|
{
|
|
886
1330
|
// extract token embeddings
|
|
887
|
-
GGML_ASSERT(embd != nullptr);
|
|
1331
|
+
GGML_ASSERT(embd.data != nullptr);
|
|
1332
|
+
const uint32_t n_embd_out = hparams.n_embd_out();
|
|
888
1333
|
|
|
889
|
-
GGML_ASSERT(n_tokens*
|
|
890
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*
|
|
1334
|
+
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
|
|
1335
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
|
|
891
1336
|
} break;
|
|
892
1337
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
893
1338
|
case LLAMA_POOLING_TYPE_CLS:
|
|
@@ -935,7 +1380,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
935
1380
|
cross.n_embd = t_embd->ne[0];
|
|
936
1381
|
cross.n_enc = t_embd->ne[1];
|
|
937
1382
|
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
|
938
|
-
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
|
1383
|
+
memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
|
|
939
1384
|
|
|
940
1385
|
const auto & batch = balloc->get_batch();
|
|
941
1386
|
|
|
@@ -955,6 +1400,128 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
955
1400
|
return 0;
|
|
956
1401
|
}
|
|
957
1402
|
|
|
1403
|
+
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
|
1404
|
+
std::map<llama_seq_id, uint32_t> seq_to_row;
|
|
1405
|
+
// how many output tokens we have seen so far for this ubatch.
|
|
1406
|
+
uint32_t local = 0;
|
|
1407
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
1408
|
+
// skip tokens that are not output.
|
|
1409
|
+
if (!ubatch.output[i]) {
|
|
1410
|
+
continue;
|
|
1411
|
+
}
|
|
1412
|
+
|
|
1413
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
1414
|
+
// row_offset is the number of output tokens before this ubatch.
|
|
1415
|
+
seq_to_row[seq_id] = row_offset + local;
|
|
1416
|
+
++local;
|
|
1417
|
+
}
|
|
1418
|
+
return seq_to_row;
|
|
1419
|
+
}
|
|
1420
|
+
|
|
1421
|
+
static void copy_tensor_async_ints(
|
|
1422
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1423
|
+
const buffer_view<llama_token> & sampled,
|
|
1424
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1425
|
+
ggml_backend_sched_t sched) {
|
|
1426
|
+
if (!sampled.has_data()) {
|
|
1427
|
+
return;
|
|
1428
|
+
}
|
|
1429
|
+
|
|
1430
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1431
|
+
auto it = seq_to_row.find(seq_id);
|
|
1432
|
+
if (it == seq_to_row.end()) {
|
|
1433
|
+
continue;
|
|
1434
|
+
}
|
|
1435
|
+
|
|
1436
|
+
const uint32_t row = it->second;
|
|
1437
|
+
GGML_ASSERT(row < sampled.size);
|
|
1438
|
+
|
|
1439
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
|
1440
|
+
|
|
1441
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1442
|
+
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
|
|
1443
|
+
}
|
|
1444
|
+
}
|
|
1445
|
+
|
|
1446
|
+
static void copy_tensor_async_floats(
|
|
1447
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1448
|
+
const buffer_view<float> & dst,
|
|
1449
|
+
size_t stride,
|
|
1450
|
+
std::vector<uint32_t> & counts,
|
|
1451
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1452
|
+
ggml_backend_sched_t sched) {
|
|
1453
|
+
if (!dst.has_data()) {
|
|
1454
|
+
return;
|
|
1455
|
+
}
|
|
1456
|
+
|
|
1457
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1458
|
+
auto it = seq_to_row.find(seq_id);
|
|
1459
|
+
if (it == seq_to_row.end()) {
|
|
1460
|
+
continue;
|
|
1461
|
+
}
|
|
1462
|
+
|
|
1463
|
+
const uint32_t row = it->second;
|
|
1464
|
+
GGML_ASSERT(row < counts.size());
|
|
1465
|
+
|
|
1466
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
|
1467
|
+
|
|
1468
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1469
|
+
float * row_ptr = dst.data + (size_t) row * stride;
|
|
1470
|
+
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1471
|
+
|
|
1472
|
+
// Update the actual number of logits/probabilities that were written for this row.
|
|
1473
|
+
counts[row] = ggml_nelements(tensor);
|
|
1474
|
+
}
|
|
1475
|
+
}
|
|
1476
|
+
|
|
1477
|
+
static void copy_tensor_async_candidates(
|
|
1478
|
+
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1479
|
+
const buffer_view<llama_token> & dst,
|
|
1480
|
+
size_t stride,
|
|
1481
|
+
std::vector<uint32_t> & counts,
|
|
1482
|
+
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1483
|
+
ggml_backend_sched_t sched) {
|
|
1484
|
+
if (!dst.has_data()) {
|
|
1485
|
+
return;
|
|
1486
|
+
}
|
|
1487
|
+
|
|
1488
|
+
for (const auto & [seq_id, tensor] : tensor_map) {
|
|
1489
|
+
auto it = seq_to_row.find(seq_id);
|
|
1490
|
+
if (it == seq_to_row.end()) {
|
|
1491
|
+
continue;
|
|
1492
|
+
}
|
|
1493
|
+
|
|
1494
|
+
const uint32_t row = it->second;
|
|
1495
|
+
GGML_ASSERT(row < counts.size());
|
|
1496
|
+
|
|
1497
|
+
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
|
1498
|
+
|
|
1499
|
+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1500
|
+
llama_token * row_ptr = dst.data + (size_t) row * stride;
|
|
1501
|
+
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1502
|
+
|
|
1503
|
+
// Update the actual number of candidates that were written.
|
|
1504
|
+
counts[row] = ggml_nelements(tensor);
|
|
1505
|
+
}
|
|
1506
|
+
}
|
|
1507
|
+
|
|
1508
|
+
static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
|
|
1509
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
|
1510
|
+
if (!ubatch.output[i]) {
|
|
1511
|
+
continue;
|
|
1512
|
+
}
|
|
1513
|
+
|
|
1514
|
+
// Check if the output token has at least one sequence without a backend sampler.
|
|
1515
|
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
|
1516
|
+
llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
1517
|
+
if (samplers.find(seq_id) == samplers.end()) {
|
|
1518
|
+
return true;
|
|
1519
|
+
}
|
|
1520
|
+
}
|
|
1521
|
+
}
|
|
1522
|
+
return false; // all sequences use backend sampling
|
|
1523
|
+
}
|
|
1524
|
+
|
|
958
1525
|
int llama_context::decode(const llama_batch & batch_inp) {
|
|
959
1526
|
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
|
960
1527
|
|
|
@@ -972,12 +1539,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
972
1539
|
const auto & hparams = model.hparams;
|
|
973
1540
|
|
|
974
1541
|
const int64_t n_vocab = vocab.n_tokens();
|
|
975
|
-
const int64_t n_embd = hparams.
|
|
1542
|
+
const int64_t n_embd = hparams.n_embd_inp();
|
|
976
1543
|
|
|
977
1544
|
// when computing embeddings, all tokens are output
|
|
978
|
-
const bool output_all
|
|
1545
|
+
const bool output_all = cparams.embeddings;
|
|
1546
|
+
const bool has_samplers = !sampling.samplers.empty();
|
|
1547
|
+
|
|
1548
|
+
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
|
|
1549
|
+
|
|
1550
|
+
// TODO: avoid this workaround in the future
|
|
1551
|
+
if (has_samplers && batch_inp.logits) {
|
|
1552
|
+
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
|
1553
|
+
|
|
1554
|
+
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
|
|
1555
|
+
if (batch_inp.logits[i] == 0) {
|
|
1556
|
+
continue;
|
|
1557
|
+
}
|
|
1558
|
+
|
|
1559
|
+
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
|
|
1560
|
+
|
|
1561
|
+
for (int32_t s = 0; s < ns; ++s) {
|
|
1562
|
+
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
|
|
1563
|
+
|
|
1564
|
+
seq_output_count[seq_id]++;
|
|
1565
|
+
if (seq_output_count[seq_id] > 1) {
|
|
1566
|
+
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
|
|
1567
|
+
__func__, seq_id, seq_output_count[seq_id]);
|
|
1568
|
+
return -1;
|
|
1569
|
+
}
|
|
1570
|
+
}
|
|
1571
|
+
}
|
|
1572
|
+
}
|
|
979
1573
|
|
|
980
|
-
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
|
1574
|
+
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
|
|
981
1575
|
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
982
1576
|
return -1;
|
|
983
1577
|
}
|
|
@@ -1007,6 +1601,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1007
1601
|
embd_seq.clear();
|
|
1008
1602
|
output_swaps.clear();
|
|
1009
1603
|
|
|
1604
|
+
sched_reserve();
|
|
1605
|
+
|
|
1010
1606
|
bool did_optimize = false;
|
|
1011
1607
|
|
|
1012
1608
|
// handle any pending shifts/copies
|
|
@@ -1131,22 +1727,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1131
1727
|
}
|
|
1132
1728
|
|
|
1133
1729
|
// extract logits
|
|
1134
|
-
if (t_logits && n_outputs > 0) {
|
|
1730
|
+
if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
|
|
1135
1731
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
1136
1732
|
GGML_ASSERT(backend_res != nullptr);
|
|
1137
|
-
GGML_ASSERT(logits != nullptr);
|
|
1733
|
+
GGML_ASSERT(logits.data != nullptr);
|
|
1138
1734
|
|
|
1139
|
-
float * logits_out = logits + n_outputs_prev*n_vocab;
|
|
1735
|
+
float * logits_out = logits.data + n_outputs_prev*n_vocab;
|
|
1140
1736
|
|
|
1141
1737
|
if (n_outputs) {
|
|
1142
1738
|
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1143
|
-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t)
|
|
1739
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
|
|
1144
1740
|
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
|
1145
1741
|
}
|
|
1146
1742
|
}
|
|
1147
1743
|
|
|
1148
1744
|
// extract embeddings
|
|
1149
|
-
if (t_embd && n_outputs > 0) {
|
|
1745
|
+
if (embd.data && t_embd && n_outputs > 0) {
|
|
1150
1746
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1151
1747
|
GGML_ASSERT(backend_embd != nullptr);
|
|
1152
1748
|
|
|
@@ -1154,13 +1750,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1154
1750
|
case LLAMA_POOLING_TYPE_NONE:
|
|
1155
1751
|
{
|
|
1156
1752
|
// extract token embeddings
|
|
1157
|
-
GGML_ASSERT(embd != nullptr);
|
|
1158
|
-
|
|
1753
|
+
GGML_ASSERT(embd.data != nullptr);
|
|
1754
|
+
const uint32_t n_embd_out = hparams.n_embd_out();
|
|
1755
|
+
float * embd_out = embd.data + n_outputs_prev*n_embd_out;
|
|
1159
1756
|
|
|
1160
1757
|
if (n_outputs) {
|
|
1161
1758
|
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1162
|
-
GGML_ASSERT((n_outputs_prev + n_outputs)*
|
|
1163
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*
|
|
1759
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
|
|
1760
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
|
|
1164
1761
|
}
|
|
1165
1762
|
} break;
|
|
1166
1763
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
@@ -1200,6 +1797,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1200
1797
|
}
|
|
1201
1798
|
}
|
|
1202
1799
|
|
|
1800
|
+
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
|
1801
|
+
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
|
|
1802
|
+
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
|
1803
|
+
const auto stride = n_vocab;
|
|
1804
|
+
|
|
1805
|
+
// async copy the sampling data from the backend to the host
|
|
1806
|
+
copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
|
|
1807
|
+
|
|
1808
|
+
copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
|
|
1809
|
+
copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
|
|
1810
|
+
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
|
|
1811
|
+
}
|
|
1812
|
+
|
|
1203
1813
|
n_outputs_prev += n_outputs;
|
|
1204
1814
|
} while (mctx->next());
|
|
1205
1815
|
|
|
@@ -1224,7 +1834,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1224
1834
|
|
|
1225
1835
|
// make the outputs have the same order they had in the user-provided batch
|
|
1226
1836
|
// note: this is mostly relevant for recurrent models atm
|
|
1227
|
-
if (!sorted_output) {
|
|
1837
|
+
if (!sorted_output && n_outputs > 1) {
|
|
1228
1838
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1229
1839
|
|
|
1230
1840
|
// TODO: is there something more efficient which also minimizes swaps?
|
|
@@ -1269,9 +1879,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1269
1879
|
|
|
1270
1880
|
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
|
|
1271
1881
|
|
|
1272
|
-
const auto n_batch
|
|
1273
|
-
const auto n_vocab
|
|
1274
|
-
const auto
|
|
1882
|
+
const auto n_batch = cparams.n_batch;
|
|
1883
|
+
const auto n_vocab = vocab.n_tokens();
|
|
1884
|
+
const auto n_embd_out = hparams.n_embd_out();
|
|
1275
1885
|
|
|
1276
1886
|
bool has_logits = true;
|
|
1277
1887
|
bool has_embd = cparams.embeddings;
|
|
@@ -1282,8 +1892,19 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1282
1892
|
has_embd = true;
|
|
1283
1893
|
}
|
|
1284
1894
|
|
|
1285
|
-
|
|
1286
|
-
|
|
1895
|
+
|
|
1896
|
+
size_t backend_float_count = 0;
|
|
1897
|
+
size_t backend_token_count = 0;
|
|
1898
|
+
|
|
1899
|
+
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
|
1900
|
+
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
|
1901
|
+
|
|
1902
|
+
// Allocate backend sampling output buffers if there are backend samplers configured.
|
|
1903
|
+
const bool has_sampling = !sampling.samplers.empty();
|
|
1904
|
+
if (has_sampling) {
|
|
1905
|
+
backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
|
|
1906
|
+
backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
|
|
1907
|
+
}
|
|
1287
1908
|
|
|
1288
1909
|
if (output_ids.empty()) {
|
|
1289
1910
|
// init, never resized afterwards
|
|
@@ -1291,7 +1912,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1291
1912
|
}
|
|
1292
1913
|
|
|
1293
1914
|
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
|
1294
|
-
const size_t new_size =
|
|
1915
|
+
const size_t new_size =
|
|
1916
|
+
(logits.size + embd.size + backend_float_count) * sizeof(float) +
|
|
1917
|
+
( backend_token_count) * sizeof(llama_token);
|
|
1295
1918
|
|
|
1296
1919
|
// alloc only when more than the current capacity is required
|
|
1297
1920
|
// TODO: also consider shrinking the buffer
|
|
@@ -1299,11 +1922,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1299
1922
|
if (buf_output) {
|
|
1300
1923
|
#ifndef NDEBUG
|
|
1301
1924
|
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
|
1302
|
-
|
|
1925
|
+
LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
1303
1926
|
#endif
|
|
1927
|
+
synchronize();
|
|
1928
|
+
|
|
1929
|
+
// TODO: not needed?
|
|
1304
1930
|
buf_output = nullptr;
|
|
1305
|
-
logits = nullptr;
|
|
1306
|
-
embd = nullptr;
|
|
1931
|
+
logits.data = nullptr;
|
|
1932
|
+
embd.data = nullptr;
|
|
1307
1933
|
}
|
|
1308
1934
|
|
|
1309
1935
|
auto * buft = ggml_backend_cpu_buffer_type();
|
|
@@ -1322,8 +1948,50 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1322
1948
|
|
|
1323
1949
|
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
|
1324
1950
|
|
|
1325
|
-
|
|
1326
|
-
|
|
1951
|
+
size_t offset = 0;
|
|
1952
|
+
uint8_t * base = (uint8_t *) output_base;
|
|
1953
|
+
|
|
1954
|
+
logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
|
|
1955
|
+
offset += logits.size * sizeof(float);
|
|
1956
|
+
|
|
1957
|
+
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
|
1958
|
+
offset += embd.size * sizeof(float);
|
|
1959
|
+
|
|
1960
|
+
if (has_sampling) {
|
|
1961
|
+
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1962
|
+
offset += sampling.logits.size * sizeof(float);
|
|
1963
|
+
|
|
1964
|
+
sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1965
|
+
offset += sampling.probs.size * sizeof(float);
|
|
1966
|
+
|
|
1967
|
+
sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
|
|
1968
|
+
offset += sampling.sampled.size * sizeof(llama_token);
|
|
1969
|
+
|
|
1970
|
+
sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1971
|
+
offset += sampling.candidates.size * sizeof(llama_token);
|
|
1972
|
+
|
|
1973
|
+
// The count vectors keep track of the actual number of logits/probs/candidates
|
|
1974
|
+
// copied from the backend for each output row.
|
|
1975
|
+
|
|
1976
|
+
sampling.logits_count.resize(n_outputs_max);
|
|
1977
|
+
sampling.probs_count.resize(n_outputs_max);
|
|
1978
|
+
sampling.candidates_count.resize(n_outputs_max);
|
|
1979
|
+
|
|
1980
|
+
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
|
1981
|
+
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
|
1982
|
+
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
|
1983
|
+
|
|
1984
|
+
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
|
1985
|
+
} else {
|
|
1986
|
+
sampling.logits = {nullptr, 0};
|
|
1987
|
+
sampling.probs = {nullptr, 0};
|
|
1988
|
+
sampling.sampled = {nullptr, 0};
|
|
1989
|
+
sampling.candidates = {nullptr, 0};
|
|
1990
|
+
|
|
1991
|
+
sampling.logits_count.clear();
|
|
1992
|
+
sampling.probs_count.clear();
|
|
1993
|
+
sampling.candidates_count.clear();
|
|
1994
|
+
}
|
|
1327
1995
|
|
|
1328
1996
|
// set all ids as invalid (negative)
|
|
1329
1997
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
@@ -1341,16 +2009,43 @@ void llama_context::output_reorder() {
|
|
|
1341
2009
|
const uint64_t i0 = output_swaps[s].i0;
|
|
1342
2010
|
const uint64_t i1 = output_swaps[s].i1;
|
|
1343
2011
|
|
|
1344
|
-
if (
|
|
2012
|
+
if (logits.size > 0) {
|
|
1345
2013
|
for (uint64_t k = 0; k < n_vocab; k++) {
|
|
1346
|
-
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
|
|
2014
|
+
std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
|
|
1347
2015
|
}
|
|
1348
2016
|
}
|
|
1349
2017
|
|
|
1350
|
-
if (
|
|
2018
|
+
if (embd.size > 0) {
|
|
1351
2019
|
for (uint64_t k = 0; k < n_embd; k++) {
|
|
1352
|
-
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
|
2020
|
+
std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
|
|
2021
|
+
}
|
|
2022
|
+
}
|
|
2023
|
+
|
|
2024
|
+
if (!sampling.samplers.empty()) {
|
|
2025
|
+
assert(sampling.logits.size > 0);
|
|
2026
|
+
assert(sampling.probs.size > 0);
|
|
2027
|
+
assert(sampling.candidates.size > 0);
|
|
2028
|
+
assert(sampling.sampled.size > 0);
|
|
2029
|
+
assert(sampling.logits_count.size() > 0);
|
|
2030
|
+
assert(sampling.probs_count.size() > 0);
|
|
2031
|
+
assert(sampling.candidates_count.size() > 0);
|
|
2032
|
+
|
|
2033
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
2034
|
+
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
|
2035
|
+
}
|
|
2036
|
+
|
|
2037
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
2038
|
+
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
|
2039
|
+
}
|
|
2040
|
+
|
|
2041
|
+
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
2042
|
+
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
|
1353
2043
|
}
|
|
2044
|
+
|
|
2045
|
+
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
|
2046
|
+
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
|
2047
|
+
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
|
2048
|
+
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
|
1354
2049
|
}
|
|
1355
2050
|
}
|
|
1356
2051
|
|
|
@@ -1361,28 +2056,36 @@ void llama_context::output_reorder() {
|
|
|
1361
2056
|
// graph
|
|
1362
2057
|
//
|
|
1363
2058
|
|
|
1364
|
-
uint32_t llama_context::graph_max_nodes() const {
|
|
1365
|
-
|
|
2059
|
+
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
|
2060
|
+
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
|
|
2061
|
+
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
|
2062
|
+
}
|
|
2063
|
+
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
|
2064
|
+
for (const auto & lora : model.loras) {
|
|
2065
|
+
res += lora->get_n_nodes();
|
|
2066
|
+
}
|
|
2067
|
+
return res;
|
|
1366
2068
|
}
|
|
1367
2069
|
|
|
1368
2070
|
llm_graph_result * llama_context::get_gf_res_reserve() const {
|
|
1369
2071
|
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
|
1370
2072
|
}
|
|
1371
2073
|
|
|
1372
|
-
ggml_cgraph * llama_context::graph_reserve(
|
|
2074
|
+
ggml_cgraph * llama_context::graph_reserve(
|
|
2075
|
+
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
|
|
1373
2076
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1374
2077
|
GGML_ASSERT(n_outputs >= 1);
|
|
1375
2078
|
|
|
1376
2079
|
if (n_tokens % n_seqs != 0) {
|
|
1377
2080
|
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
|
1378
|
-
n_outputs = std::
|
|
2081
|
+
n_outputs = std::max(n_outputs, n_tokens);
|
|
1379
2082
|
|
|
1380
2083
|
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
1381
2084
|
}
|
|
1382
2085
|
|
|
1383
2086
|
ggml_backend_sched_reset(sched.get());
|
|
1384
2087
|
|
|
1385
|
-
// when the scheduler is reset, we
|
|
2088
|
+
// when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
|
|
1386
2089
|
gf_res_prev->reset();
|
|
1387
2090
|
|
|
1388
2091
|
// store the n_outputs as it is, and restore it afterwards
|
|
@@ -1394,6 +2097,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
1394
2097
|
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
|
1395
2098
|
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
|
1396
2099
|
|
|
2100
|
+
// set one output token per sequence in order to activate all backend samplers
|
|
2101
|
+
std::vector<llama_seq_id> seq_ids(n_seqs);
|
|
2102
|
+
for (uint32_t i = 0; i < n_seqs; ++i) {
|
|
2103
|
+
seq_ids[i] = i;
|
|
2104
|
+
ubatch.n_seq_id[i] = 1;
|
|
2105
|
+
ubatch.seq_id[i] = &seq_ids[i];
|
|
2106
|
+
ubatch.output[i] = true;
|
|
2107
|
+
}
|
|
2108
|
+
|
|
1397
2109
|
auto * res = gf_res_reserve.get();
|
|
1398
2110
|
|
|
1399
2111
|
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
|
@@ -1406,8 +2118,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
|
1406
2118
|
|
|
1407
2119
|
// initialize scheduler with the specified graph
|
|
1408
2120
|
if (split_only) {
|
|
1409
|
-
|
|
2121
|
+
if (sizes) {
|
|
2122
|
+
ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
|
|
2123
|
+
} else {
|
|
2124
|
+
ggml_backend_sched_split_graph(sched.get(), gf);
|
|
2125
|
+
}
|
|
1410
2126
|
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
2127
|
+
GGML_ASSERT(!sizes);
|
|
1411
2128
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
|
1412
2129
|
return nullptr;
|
|
1413
2130
|
}
|
|
@@ -1419,7 +2136,7 @@ llm_graph_params llama_context::graph_params(
|
|
|
1419
2136
|
llm_graph_result * res,
|
|
1420
2137
|
const llama_ubatch & ubatch,
|
|
1421
2138
|
const llama_memory_context_i * mctx,
|
|
1422
|
-
|
|
2139
|
+
llm_graph_type gtype) const {
|
|
1423
2140
|
return {
|
|
1424
2141
|
/*.arch =*/ model.arch,
|
|
1425
2142
|
/*.hparams =*/ model.hparams,
|
|
@@ -1428,10 +2145,11 @@ llm_graph_params llama_context::graph_params(
|
|
|
1428
2145
|
/*.gtype =*/ gtype,
|
|
1429
2146
|
/*.sched =*/ sched.get(),
|
|
1430
2147
|
/*.backend_cpu =*/ backend_cpu,
|
|
1431
|
-
/*.cvec =*/
|
|
1432
|
-
/*.loras =*/
|
|
2148
|
+
/*.cvec =*/ cvec.get(),
|
|
2149
|
+
/*.loras =*/ loras.get(),
|
|
1433
2150
|
/*.mctx =*/ mctx,
|
|
1434
2151
|
/*.cross =*/ &cross,
|
|
2152
|
+
/*.samplers =*/ sampling.samplers,
|
|
1435
2153
|
/*.n_outputs =*/ n_outputs,
|
|
1436
2154
|
/*.cb =*/ graph_get_cb(),
|
|
1437
2155
|
/*.res =*/ res,
|
|
@@ -1475,16 +2193,9 @@ llm_graph_cb llama_context::graph_get_cb() const {
|
|
|
1475
2193
|
ggml_set_name(cur, name);
|
|
1476
2194
|
}
|
|
1477
2195
|
|
|
1478
|
-
if (!cparams.offload_kqv) {
|
|
1479
|
-
if (strcmp(name, "kqv_merged_cont") == 0) {
|
|
1480
|
-
// all nodes between the KV store and the attention output are run on the CPU
|
|
1481
|
-
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
|
|
1482
|
-
}
|
|
1483
|
-
}
|
|
1484
|
-
|
|
1485
2196
|
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
1486
2197
|
// FIXME: fix in ggml_backend_sched
|
|
1487
|
-
const bool full_offload = model.
|
|
2198
|
+
const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
|
|
1488
2199
|
if (ubatch.n_tokens < 32 || full_offload) {
|
|
1489
2200
|
if (il != -1 && strcmp(name, "norm") == 0) {
|
|
1490
2201
|
const auto & dev_layer = model.dev_layer(il);
|
|
@@ -1833,60 +2544,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
1833
2544
|
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
|
1834
2545
|
}
|
|
1835
2546
|
|
|
1836
|
-
// write output ids
|
|
1837
|
-
{
|
|
1838
|
-
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
|
1839
|
-
|
|
1840
|
-
const auto n_outputs = this->n_outputs;
|
|
1841
|
-
const auto & output_ids = this->output_ids;
|
|
1842
|
-
|
|
1843
|
-
std::vector<int32_t> w_output_pos;
|
|
1844
|
-
|
|
1845
|
-
w_output_pos.resize(n_outputs);
|
|
1846
|
-
|
|
1847
|
-
// build a more compact representation of the output ids
|
|
1848
|
-
for (size_t i = 0; i < n_batch(); ++i) {
|
|
1849
|
-
// map an output id to a position in the batch
|
|
1850
|
-
int64_t pos = output_ids[i];
|
|
1851
|
-
if (pos >= 0) {
|
|
1852
|
-
GGML_ASSERT(pos < n_outputs);
|
|
1853
|
-
w_output_pos[pos] = i;
|
|
1854
|
-
}
|
|
1855
|
-
}
|
|
1856
|
-
|
|
1857
|
-
io.write(&n_outputs, sizeof(n_outputs));
|
|
1858
|
-
|
|
1859
|
-
if (n_outputs) {
|
|
1860
|
-
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
|
|
1861
|
-
}
|
|
1862
|
-
}
|
|
1863
|
-
|
|
1864
|
-
// write logits
|
|
1865
|
-
{
|
|
1866
|
-
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
|
1867
|
-
|
|
1868
|
-
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
|
1869
|
-
|
|
1870
|
-
io.write(&logits_size, sizeof(logits_size));
|
|
1871
|
-
|
|
1872
|
-
if (logits_size) {
|
|
1873
|
-
io.write(logits, logits_size * sizeof(float));
|
|
1874
|
-
}
|
|
1875
|
-
}
|
|
1876
|
-
|
|
1877
|
-
// write embeddings
|
|
1878
|
-
{
|
|
1879
|
-
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
|
1880
|
-
|
|
1881
|
-
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
|
|
1882
|
-
|
|
1883
|
-
io.write(&embd_size, sizeof(embd_size));
|
|
1884
|
-
|
|
1885
|
-
if (embd_size) {
|
|
1886
|
-
io.write(embd, embd_size * sizeof(float));
|
|
1887
|
-
}
|
|
1888
|
-
}
|
|
1889
|
-
|
|
1890
2547
|
if (memory != nullptr) {
|
|
1891
2548
|
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
|
1892
2549
|
memory->state_write(io);
|
|
@@ -1912,67 +2569,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
1912
2569
|
// TODO: add more info which needs to be identical but which is not verified otherwise
|
|
1913
2570
|
}
|
|
1914
2571
|
|
|
1915
|
-
// read output ids
|
|
1916
|
-
{
|
|
1917
|
-
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
|
|
1918
|
-
|
|
1919
|
-
auto n_outputs = this->n_outputs;
|
|
1920
|
-
io.read_to(&n_outputs, sizeof(n_outputs));
|
|
1921
|
-
|
|
1922
|
-
if (n_outputs > output_reserve(n_outputs)) {
|
|
1923
|
-
throw std::runtime_error("could not reserve outputs");
|
|
1924
|
-
}
|
|
1925
|
-
|
|
1926
|
-
std::vector<int32_t> output_pos;
|
|
1927
|
-
|
|
1928
|
-
if (n_outputs) {
|
|
1929
|
-
output_pos.resize(n_outputs);
|
|
1930
|
-
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
|
1931
|
-
|
|
1932
|
-
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
|
1933
|
-
int32_t id = output_pos[i];
|
|
1934
|
-
if ((uint32_t) id >= n_batch()) {
|
|
1935
|
-
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
|
|
1936
|
-
}
|
|
1937
|
-
this->output_ids[id] = i;
|
|
1938
|
-
}
|
|
1939
|
-
|
|
1940
|
-
this->n_outputs = n_outputs;
|
|
1941
|
-
}
|
|
1942
|
-
}
|
|
1943
|
-
|
|
1944
|
-
// read logits
|
|
1945
|
-
{
|
|
1946
|
-
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
|
|
1947
|
-
|
|
1948
|
-
uint64_t logits_size;
|
|
1949
|
-
io.read_to(&logits_size, sizeof(logits_size));
|
|
1950
|
-
|
|
1951
|
-
if (this->logits_size < logits_size) {
|
|
1952
|
-
throw std::runtime_error("logits buffer too small");
|
|
1953
|
-
}
|
|
1954
|
-
|
|
1955
|
-
if (logits_size) {
|
|
1956
|
-
io.read_to(this->logits, logits_size * sizeof(float));
|
|
1957
|
-
}
|
|
1958
|
-
}
|
|
1959
|
-
|
|
1960
|
-
// read embeddings
|
|
1961
|
-
{
|
|
1962
|
-
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
|
|
1963
|
-
|
|
1964
|
-
uint64_t embd_size;
|
|
1965
|
-
io.read_to(&embd_size, sizeof(embd_size));
|
|
1966
|
-
|
|
1967
|
-
if (this->embd_size < embd_size) {
|
|
1968
|
-
throw std::runtime_error("embeddings buffer too small");
|
|
1969
|
-
}
|
|
1970
|
-
|
|
1971
|
-
if (embd_size) {
|
|
1972
|
-
io.read_to(this->embd, embd_size * sizeof(float));
|
|
1973
|
-
}
|
|
1974
|
-
}
|
|
1975
|
-
|
|
1976
2572
|
if (memory) {
|
|
1977
2573
|
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
|
1978
2574
|
|
|
@@ -2029,15 +2625,26 @@ void llama_context::perf_reset() {
|
|
|
2029
2625
|
|
|
2030
2626
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
|
2031
2627
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
|
2032
|
-
for (const auto &
|
|
2033
|
-
ret[
|
|
2628
|
+
for (const auto & [buft, size] : model.memory_breakdown()) {
|
|
2629
|
+
ret[buft].model += size;
|
|
2034
2630
|
}
|
|
2035
|
-
|
|
2036
|
-
|
|
2631
|
+
if (memory) {
|
|
2632
|
+
for (const auto & [buft, size] : memory->memory_breakdown()) {
|
|
2633
|
+
ret[buft].context += size;
|
|
2634
|
+
}
|
|
2037
2635
|
}
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2636
|
+
if (model.hparams.no_alloc) {
|
|
2637
|
+
for (size_t i = 0; i < backends.size(); ++i) {
|
|
2638
|
+
ggml_backend_t backend = backends[i].get();
|
|
2639
|
+
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
|
2640
|
+
ret[buft].compute += backend_buf_exp_size[i];
|
|
2641
|
+
}
|
|
2642
|
+
} else {
|
|
2643
|
+
for (const auto & backend_ptr : backends) {
|
|
2644
|
+
ggml_backend_t backend = backend_ptr.get();
|
|
2645
|
+
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
|
2646
|
+
ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
2647
|
+
}
|
|
2041
2648
|
}
|
|
2042
2649
|
return ret;
|
|
2043
2650
|
}
|
|
@@ -2094,6 +2701,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
|
|
|
2094
2701
|
llama_set_param(model->cls_b, param_filter, param_filter_ud);
|
|
2095
2702
|
llama_set_param(model->cls_out, param_filter, param_filter_ud);
|
|
2096
2703
|
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
|
|
2704
|
+
llama_set_param(model->cls_norm, param_filter, param_filter_ud);
|
|
2097
2705
|
|
|
2098
2706
|
for (struct llama_layer & layer : model->layers) {
|
|
2099
2707
|
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
|
@@ -2130,7 +2738,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2130
2738
|
batch.logits [pos_batch] = true;
|
|
2131
2739
|
}
|
|
2132
2740
|
|
|
2133
|
-
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.
|
|
2741
|
+
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
|
2134
2742
|
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
|
2135
2743
|
return;
|
|
2136
2744
|
}
|
|
@@ -2185,7 +2793,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2185
2793
|
};
|
|
2186
2794
|
ctx_compute_opt = ggml_init(params);
|
|
2187
2795
|
}
|
|
2188
|
-
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->
|
|
2796
|
+
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
|
|
2189
2797
|
ggml_opt_alloc(opt_ctx, train);
|
|
2190
2798
|
|
|
2191
2799
|
res->set_inputs(&ubatch);
|
|
@@ -2295,6 +2903,8 @@ llama_context_params llama_context_default_params() {
|
|
|
2295
2903
|
/*.op_offload =*/ true,
|
|
2296
2904
|
/*.swa_full =*/ true,
|
|
2297
2905
|
/*.kv_unified =*/ false,
|
|
2906
|
+
/*.sampler =*/ nullptr,
|
|
2907
|
+
/*.n_sampler =*/ 0,
|
|
2298
2908
|
};
|
|
2299
2909
|
|
|
2300
2910
|
return result;
|
|
@@ -2325,19 +2935,23 @@ llama_context * llama_init_from_model(
|
|
|
2325
2935
|
|
|
2326
2936
|
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
|
2327
2937
|
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
|
2328
|
-
|
|
2329
|
-
|
|
2330
|
-
|
|
2331
|
-
|
|
2938
|
+
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
|
2939
|
+
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
|
2940
|
+
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
|
2941
|
+
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
|
|
2942
|
+
return nullptr;
|
|
2943
|
+
}
|
|
2332
2944
|
}
|
|
2333
2945
|
}
|
|
2334
2946
|
|
|
2335
2947
|
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
|
2336
2948
|
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
|
2337
|
-
|
|
2338
|
-
|
|
2339
|
-
|
|
2340
|
-
|
|
2949
|
+
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
|
2950
|
+
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
|
2951
|
+
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
|
|
2952
|
+
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
|
|
2953
|
+
return nullptr;
|
|
2954
|
+
}
|
|
2341
2955
|
}
|
|
2342
2956
|
}
|
|
2343
2957
|
|
|
@@ -2346,6 +2960,13 @@ llama_context * llama_init_from_model(
|
|
|
2346
2960
|
return nullptr;
|
|
2347
2961
|
}
|
|
2348
2962
|
|
|
2963
|
+
if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
|
|
2964
|
+
params.pooling_type != model->hparams.pooling_type) {
|
|
2965
|
+
//user-specified pooling-type is different from the model default
|
|
2966
|
+
LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
|
|
2967
|
+
model->hparams.pooling_type, params.pooling_type);
|
|
2968
|
+
}
|
|
2969
|
+
|
|
2349
2970
|
try {
|
|
2350
2971
|
auto * ctx = new llama_context(*model, params);
|
|
2351
2972
|
return ctx;
|
|
@@ -2371,6 +2992,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
|
|
|
2371
2992
|
return ctx->n_ctx();
|
|
2372
2993
|
}
|
|
2373
2994
|
|
|
2995
|
+
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
|
|
2996
|
+
return ctx->n_ctx_seq();
|
|
2997
|
+
}
|
|
2998
|
+
|
|
2374
2999
|
uint32_t llama_n_batch(const llama_context * ctx) {
|
|
2375
3000
|
return ctx->n_batch();
|
|
2376
3001
|
}
|
|
@@ -2443,7 +3068,15 @@ float * llama_get_logits(llama_context * ctx) {
|
|
|
2443
3068
|
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
|
2444
3069
|
ctx->synchronize();
|
|
2445
3070
|
|
|
2446
|
-
|
|
3071
|
+
float * res = nullptr;
|
|
3072
|
+
|
|
3073
|
+
res = ctx->get_sampled_logits_ith(i);
|
|
3074
|
+
|
|
3075
|
+
if (!res) {
|
|
3076
|
+
res = ctx->get_logits_ith(i);
|
|
3077
|
+
}
|
|
3078
|
+
|
|
3079
|
+
return res;
|
|
2447
3080
|
}
|
|
2448
3081
|
|
|
2449
3082
|
float * llama_get_embeddings(llama_context * ctx) {
|
|
@@ -2464,37 +3097,89 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|
|
2464
3097
|
return ctx->get_embeddings_seq(seq_id);
|
|
2465
3098
|
}
|
|
2466
3099
|
|
|
2467
|
-
|
|
3100
|
+
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
|
|
3101
|
+
return ctx->set_sampler(seq_id, smpl);
|
|
3102
|
+
}
|
|
2468
3103
|
|
|
2469
|
-
int32_t
|
|
2470
|
-
|
|
2471
|
-
llama_adapter_lora * adapter,
|
|
2472
|
-
float scale) {
|
|
2473
|
-
ctx->set_adapter_lora(adapter, scale);
|
|
3104
|
+
llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
|
|
3105
|
+
ctx->synchronize();
|
|
2474
3106
|
|
|
2475
|
-
return
|
|
3107
|
+
return ctx->get_sampled_token_ith(i);
|
|
2476
3108
|
}
|
|
2477
3109
|
|
|
2478
|
-
int32_t
|
|
2479
|
-
|
|
2480
|
-
llama_adapter_lora * adapter) {
|
|
2481
|
-
bool res = ctx->rm_adapter_lora(adapter);
|
|
3110
|
+
float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
|
|
3111
|
+
ctx->synchronize();
|
|
2482
3112
|
|
|
2483
|
-
return
|
|
3113
|
+
return ctx->get_sampled_probs_ith(i);
|
|
2484
3114
|
}
|
|
2485
3115
|
|
|
2486
|
-
|
|
2487
|
-
ctx->
|
|
3116
|
+
float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
|
|
3117
|
+
ctx->synchronize();
|
|
3118
|
+
|
|
3119
|
+
return ctx->get_sampled_logits_ith(i);
|
|
3120
|
+
}
|
|
3121
|
+
|
|
3122
|
+
llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
|
|
3123
|
+
ctx->synchronize();
|
|
3124
|
+
|
|
3125
|
+
return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
|
|
3126
|
+
}
|
|
3127
|
+
|
|
3128
|
+
uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
|
|
3129
|
+
ctx->synchronize();
|
|
3130
|
+
|
|
3131
|
+
return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
|
|
3132
|
+
}
|
|
3133
|
+
|
|
3134
|
+
uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
|
|
3135
|
+
ctx->synchronize();
|
|
3136
|
+
|
|
3137
|
+
return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
|
|
3138
|
+
}
|
|
3139
|
+
|
|
3140
|
+
uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
|
3141
|
+
ctx->synchronize();
|
|
3142
|
+
|
|
3143
|
+
return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
|
|
3144
|
+
}
|
|
3145
|
+
|
|
3146
|
+
struct ggml_cgraph * llama_graph_reserve(
|
|
3147
|
+
struct llama_context * ctx,
|
|
3148
|
+
uint32_t n_tokens,
|
|
3149
|
+
uint32_t n_seqs,
|
|
3150
|
+
uint32_t n_outputs) {
|
|
3151
|
+
auto * memory = ctx->get_memory();
|
|
3152
|
+
llama_memory_context_ptr mctx;
|
|
3153
|
+
if (memory) {
|
|
3154
|
+
mctx = memory->init_full();
|
|
3155
|
+
}
|
|
3156
|
+
return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get());
|
|
3157
|
+
}
|
|
3158
|
+
|
|
3159
|
+
// llama adapter API
|
|
3160
|
+
|
|
3161
|
+
int32_t llama_set_adapters_lora(
|
|
3162
|
+
llama_context * ctx,
|
|
3163
|
+
llama_adapter_lora ** adapters,
|
|
3164
|
+
size_t n_adapters,
|
|
3165
|
+
float * scales) {
|
|
3166
|
+
if (adapters == nullptr || scales == nullptr) {
|
|
3167
|
+
GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
|
|
3168
|
+
}
|
|
3169
|
+
|
|
3170
|
+
ctx->set_adapters_lora(adapters, n_adapters, scales);
|
|
3171
|
+
|
|
3172
|
+
return 0;
|
|
2488
3173
|
}
|
|
2489
3174
|
|
|
2490
|
-
int32_t
|
|
3175
|
+
int32_t llama_set_adapter_cvec(
|
|
2491
3176
|
llama_context * ctx,
|
|
2492
|
-
|
|
2493
|
-
|
|
2494
|
-
|
|
2495
|
-
|
|
2496
|
-
|
|
2497
|
-
bool res = ctx->
|
|
3177
|
+
const float * data,
|
|
3178
|
+
size_t len,
|
|
3179
|
+
int32_t n_embd,
|
|
3180
|
+
int32_t il_start,
|
|
3181
|
+
int32_t il_end) {
|
|
3182
|
+
bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
|
|
2498
3183
|
|
|
2499
3184
|
return res ? 0 : -1;
|
|
2500
3185
|
}
|