whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between the two package versions as published in their public registry.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
#include "llama-memory.h"
|
|
8
8
|
#include "llama-mmap.h"
|
|
9
9
|
#include "llama-model.h"
|
|
10
|
+
#include "llama-ext.h"
|
|
10
11
|
|
|
11
12
|
#include <cinttypes>
|
|
12
13
|
#include <cmath>
|
|
@@ -22,6 +23,8 @@ llama_context::llama_context(
|
|
|
22
23
|
const llama_model & model,
|
|
23
24
|
llama_context_params params) :
|
|
24
25
|
model(model),
|
|
26
|
+
cvec(std::make_unique<llama_adapter_cvec>()),
|
|
27
|
+
loras(std::make_unique<llama_adapter_loras>()),
|
|
25
28
|
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
|
26
29
|
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
|
27
30
|
// may need to be backend-dependent
|
|
@@ -146,6 +149,11 @@ llama_context::llama_context(
|
|
|
146
149
|
}
|
|
147
150
|
|
|
148
151
|
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
|
152
|
+
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
|
|
153
|
+
|
|
154
|
+
cparams.fused_gdn_ar = true;
|
|
155
|
+
cparams.fused_gdn_ch = true;
|
|
156
|
+
cparams.auto_fgdn = true;
|
|
149
157
|
|
|
150
158
|
// with causal attention, the batch size is limited by the context size
|
|
151
159
|
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
|
@@ -155,6 +163,9 @@ llama_context::llama_context(
|
|
|
155
163
|
cparams.op_offload = params.op_offload;
|
|
156
164
|
cparams.kv_unified = params.kv_unified;
|
|
157
165
|
|
|
166
|
+
// initialized later
|
|
167
|
+
cparams.pipeline_parallel = false;
|
|
168
|
+
|
|
158
169
|
{
|
|
159
170
|
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
|
|
160
171
|
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
|
|
@@ -249,11 +260,7 @@ llama_context::llama_context(
|
|
|
249
260
|
|
|
250
261
|
// graph outputs buffer
|
|
251
262
|
{
|
|
252
|
-
|
|
253
|
-
// Create a dummy batch for initialization.
|
|
254
|
-
llama_batch dummy_batch = {};
|
|
255
|
-
dummy_batch.n_tokens = 0;
|
|
256
|
-
if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
|
|
263
|
+
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
|
|
257
264
|
throw std::runtime_error("failed to reserve initial output buffer");
|
|
258
265
|
}
|
|
259
266
|
|
|
@@ -302,16 +309,6 @@ llama_context::llama_context(
|
|
|
302
309
|
|
|
303
310
|
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
|
304
311
|
|
|
305
|
-
const uint32_t n_seqs = cparams.n_seq_max;
|
|
306
|
-
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
307
|
-
|
|
308
|
-
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
|
309
|
-
|
|
310
|
-
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
311
|
-
|
|
312
|
-
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
|
313
|
-
gf_res_reserve.reset(new llm_graph_result(max_nodes));
|
|
314
|
-
|
|
315
312
|
// TODO: move these checks to ggml_backend_sched
|
|
316
313
|
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
317
314
|
bool pipeline_parallel =
|
|
@@ -327,6 +324,7 @@ llama_context::llama_context(
|
|
|
327
324
|
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
|
328
325
|
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
|
329
326
|
// ignore CPU backend
|
|
327
|
+
// TODO: should we ignore ACCEL types too?
|
|
330
328
|
continue;
|
|
331
329
|
}
|
|
332
330
|
auto * dev = ggml_backend_get_device(backend.get());
|
|
@@ -340,177 +338,308 @@ llama_context::llama_context(
|
|
|
340
338
|
}
|
|
341
339
|
}
|
|
342
340
|
|
|
343
|
-
|
|
341
|
+
cparams.pipeline_parallel = pipeline_parallel;
|
|
344
342
|
|
|
345
|
-
if (pipeline_parallel) {
|
|
346
|
-
LLAMA_LOG_INFO("%s: pipeline parallelism enabled
|
|
343
|
+
if (cparams.pipeline_parallel) {
|
|
344
|
+
LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
|
|
345
|
+
|
|
346
|
+
if (!graph_reuse_disable) {
|
|
347
|
+
// TODO: figure out a way to make graph reuse work with pipeline parallelism
|
|
348
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/20463
|
|
349
|
+
LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__);
|
|
350
|
+
|
|
351
|
+
graph_reuse_disable = true;
|
|
352
|
+
}
|
|
347
353
|
}
|
|
348
354
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
throw std::runtime_error("failed to initialize memory module");
|
|
355
|
+
sched_reserve();
|
|
356
|
+
|
|
357
|
+
if (!cparams.flash_attn) {
|
|
358
|
+
if (ggml_is_quantized(params.type_v)) {
|
|
359
|
+
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
355
360
|
}
|
|
356
361
|
}
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
// Initialize the full vocabulary token ids for backend samplers.
|
|
365
|
+
{
|
|
366
|
+
const int n_vocab = model.vocab.n_tokens();
|
|
357
367
|
|
|
358
|
-
|
|
368
|
+
sampling.token_ids_full_vocab.resize(n_vocab);
|
|
369
|
+
for (int i = 0; i < n_vocab; ++i) {
|
|
370
|
+
sampling.token_ids_full_vocab[i] = i;
|
|
371
|
+
}
|
|
372
|
+
}
|
|
373
|
+
}
|
|
359
374
|
|
|
360
|
-
|
|
361
|
-
|
|
375
|
+
llama_context::~llama_context() {
|
|
376
|
+
if (!model.hparams.no_alloc) {
|
|
377
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
378
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
379
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
362
380
|
|
|
363
|
-
|
|
381
|
+
const size_t size_exp = backend_buf_exp_size[i];
|
|
382
|
+
const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
383
|
+
if (size_exp == size_act) {
|
|
384
|
+
LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
|
|
385
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
386
|
+
} else {
|
|
387
|
+
LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
|
388
|
+
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
389
|
+
}
|
|
390
|
+
}
|
|
391
|
+
}
|
|
392
|
+
ggml_opt_free(opt_ctx);
|
|
393
|
+
}
|
|
364
394
|
|
|
365
|
-
|
|
366
|
-
|
|
395
|
+
void llama_context::sched_reserve() {
|
|
396
|
+
if (!sched_need_reserve) {
|
|
397
|
+
return;
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
sched_need_reserve = false;
|
|
401
|
+
|
|
402
|
+
LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
|
|
403
|
+
|
|
404
|
+
synchronize();
|
|
405
|
+
|
|
406
|
+
const int64_t t_start_us = ggml_time_us();
|
|
407
|
+
|
|
408
|
+
const uint32_t n_seqs = cparams.n_seq_max;
|
|
409
|
+
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
410
|
+
|
|
411
|
+
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
|
412
|
+
|
|
413
|
+
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
|
414
|
+
|
|
415
|
+
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
|
416
|
+
gf_res_reserve.reset(new llm_graph_result(max_nodes));
|
|
417
|
+
|
|
418
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
|
|
419
|
+
|
|
420
|
+
llama_memory_context_ptr mctx;
|
|
421
|
+
if (memory) {
|
|
422
|
+
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
|
423
|
+
mctx = memory->init_full();
|
|
424
|
+
if (!mctx) {
|
|
425
|
+
throw std::runtime_error("failed to initialize memory module");
|
|
426
|
+
}
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
// avoid reserving graphs with zero outputs - assume one output per sequence
|
|
430
|
+
const int n_outputs = n_seqs;
|
|
431
|
+
|
|
432
|
+
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
433
|
+
|
|
434
|
+
// resolve automatic Flash Attention use
|
|
435
|
+
if (cparams.auto_fa) {
|
|
436
|
+
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
|
437
|
+
if (!gf) {
|
|
438
|
+
throw std::runtime_error("failed to reserve graph for Flash Attention check");
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
|
|
442
|
+
bool fa_device_mismatch = false;
|
|
443
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
444
|
+
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
445
|
+
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
|
|
446
|
+
continue;
|
|
447
|
+
}
|
|
448
|
+
ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
449
|
+
|
|
450
|
+
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
|
|
451
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
452
|
+
const int il = std::stoi(n->name + prefix_len);
|
|
453
|
+
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
454
|
+
if (device_fa != device_kv) {
|
|
455
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
|
|
456
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
457
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
|
|
458
|
+
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
|
|
459
|
+
fa_device_mismatch = true;
|
|
460
|
+
break;
|
|
461
|
+
}
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
if (fa_device_mismatch) {
|
|
465
|
+
cparams.flash_attn = false;
|
|
466
|
+
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
|
|
467
|
+
} else {
|
|
468
|
+
cparams.flash_attn = true;
|
|
469
|
+
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
cparams.auto_fa = false;
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
if (cparams.auto_fgdn) {
|
|
476
|
+
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
|
|
477
|
+
|
|
478
|
+
if (cparams.fused_gdn_ar) {
|
|
367
479
|
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
|
368
480
|
if (!gf) {
|
|
369
|
-
throw std::runtime_error("failed to
|
|
481
|
+
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
|
|
370
482
|
}
|
|
371
483
|
|
|
372
|
-
const size_t prefix_len = strlen(
|
|
373
|
-
bool
|
|
484
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
|
|
485
|
+
bool gdn_device_mismatch = false;
|
|
374
486
|
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
375
487
|
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
376
|
-
if (n->op !=
|
|
488
|
+
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
|
377
489
|
continue;
|
|
378
490
|
}
|
|
379
|
-
ggml_backend_dev_t
|
|
380
|
-
ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
491
|
+
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
381
492
|
|
|
382
|
-
|
|
383
|
-
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
|
493
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
|
|
384
494
|
const int il = std::stoi(n->name + prefix_len);
|
|
385
495
|
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
386
|
-
if (
|
|
387
|
-
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
fa_device_mismatch = true;
|
|
496
|
+
if (device_gdn != device_kv) {
|
|
497
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
|
498
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
499
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
|
500
|
+
gdn_device_mismatch = true;
|
|
392
501
|
break;
|
|
393
502
|
}
|
|
394
503
|
}
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
|
|
400
|
-
}
|
|
504
|
+
|
|
505
|
+
if (gdn_device_mismatch) {
|
|
506
|
+
cparams.fused_gdn_ar = false;
|
|
507
|
+
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
|
|
401
508
|
} else {
|
|
402
|
-
|
|
403
|
-
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
|
|
509
|
+
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
|
|
404
510
|
}
|
|
405
511
|
}
|
|
406
512
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
{
|
|
416
|
-
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
|
417
|
-
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
|
513
|
+
if (cparams.fused_gdn_ch) {
|
|
514
|
+
// more than one token in the batch per sequence in order to take the chunked path
|
|
515
|
+
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
|
|
516
|
+
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
|
|
517
|
+
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
|
|
518
|
+
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
|
|
519
|
+
const uint32_t n_tokens_ch = 16*n_seqs;
|
|
520
|
+
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
|
|
418
521
|
if (!gf) {
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
522
|
+
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
|
|
526
|
+
bool gdn_device_mismatch = false;
|
|
527
|
+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
|
528
|
+
ggml_tensor * n = ggml_graph_node(gf, i);
|
|
529
|
+
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
|
530
|
+
continue;
|
|
423
531
|
}
|
|
424
|
-
|
|
425
|
-
|
|
532
|
+
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
|
533
|
+
|
|
534
|
+
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
|
|
535
|
+
const int il = std::stoi(n->name + prefix_len);
|
|
536
|
+
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
|
537
|
+
if (device_gdn != device_kv) {
|
|
538
|
+
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
|
539
|
+
"is assigned to device %s (usually due to missing support)\n",
|
|
540
|
+
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
|
541
|
+
gdn_device_mismatch = true;
|
|
542
|
+
break;
|
|
426
543
|
}
|
|
427
544
|
}
|
|
428
545
|
|
|
429
|
-
|
|
430
|
-
|
|
546
|
+
if (gdn_device_mismatch) {
|
|
547
|
+
cparams.fused_gdn_ch = false;
|
|
548
|
+
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
|
|
549
|
+
} else {
|
|
550
|
+
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
|
|
551
|
+
}
|
|
431
552
|
}
|
|
432
553
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
|
436
|
-
if (!gf) {
|
|
437
|
-
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
438
|
-
}
|
|
554
|
+
cparams.auto_fgdn = false;
|
|
555
|
+
}
|
|
439
556
|
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
557
|
+
// reserve worst-case graph
|
|
558
|
+
int n_splits_pp = -1;
|
|
559
|
+
int n_nodes_pp = -1;
|
|
443
560
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
561
|
+
int n_splits_tg = -1;
|
|
562
|
+
int n_nodes_tg = -1;
|
|
563
|
+
|
|
564
|
+
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
|
565
|
+
{
|
|
566
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
|
567
|
+
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
|
568
|
+
if (!gf) {
|
|
569
|
+
if (cparams.pipeline_parallel) {
|
|
570
|
+
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
|
571
|
+
cparams.pipeline_parallel = false;
|
|
572
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
|
|
573
|
+
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
574
|
+
}
|
|
451
575
|
if (!gf) {
|
|
452
576
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
453
577
|
}
|
|
454
578
|
}
|
|
455
579
|
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
if (!model.hparams.no_alloc) {
|
|
460
|
-
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
461
|
-
}
|
|
462
|
-
if (backend_buf_exp_size[i] > 1) {
|
|
463
|
-
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
464
|
-
ggml_backend_buft_name(buft),
|
|
465
|
-
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
|
466
|
-
}
|
|
467
|
-
}
|
|
580
|
+
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
|
581
|
+
n_nodes_pp = ggml_graph_n_nodes(gf);
|
|
582
|
+
}
|
|
468
583
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
584
|
+
// reserve with tg (token generation) graph to get the number of splits and nodes
|
|
585
|
+
{
|
|
586
|
+
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
|
587
|
+
if (!gf) {
|
|
588
|
+
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
473
589
|
}
|
|
474
590
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
} else {
|
|
478
|
-
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
479
|
-
}
|
|
591
|
+
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
|
592
|
+
n_nodes_tg = ggml_graph_n_nodes(gf);
|
|
480
593
|
}
|
|
481
594
|
|
|
482
|
-
//
|
|
595
|
+
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
483
596
|
{
|
|
484
|
-
|
|
597
|
+
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
|
|
598
|
+
//
|
|
599
|
+
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
|
600
|
+
//
|
|
601
|
+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
|
|
602
|
+
if (!gf) {
|
|
603
|
+
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
604
|
+
}
|
|
605
|
+
}
|
|
485
606
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
607
|
+
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
|
608
|
+
ggml_backend_t backend = backend_ptrs[i];
|
|
609
|
+
ggml_backend_buffer_type_t buft = backend_buft[i];
|
|
610
|
+
if (!model.hparams.no_alloc) {
|
|
611
|
+
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
|
612
|
+
}
|
|
613
|
+
if (backend_buf_exp_size[i] > 1) {
|
|
614
|
+
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
|
615
|
+
ggml_backend_buft_name(buft),
|
|
616
|
+
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
|
489
617
|
}
|
|
490
618
|
}
|
|
491
|
-
}
|
|
492
619
|
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
620
|
+
if (n_nodes_pp == n_nodes_tg) {
|
|
621
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
|
622
|
+
} else {
|
|
623
|
+
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
|
624
|
+
}
|
|
498
625
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
504
|
-
} else {
|
|
505
|
-
LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
|
506
|
-
__func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
|
507
|
-
}
|
|
508
|
-
}
|
|
626
|
+
if (n_splits_pp == n_splits_tg) {
|
|
627
|
+
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
|
628
|
+
} else {
|
|
629
|
+
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
|
509
630
|
}
|
|
510
|
-
|
|
631
|
+
|
|
632
|
+
const int64_t t_end_us = ggml_time_us();
|
|
633
|
+
|
|
634
|
+
LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n",
|
|
635
|
+
__func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get()));
|
|
511
636
|
}
|
|
512
637
|
|
|
513
638
|
void llama_context::synchronize() {
|
|
639
|
+
if (!sched) {
|
|
640
|
+
return;
|
|
641
|
+
}
|
|
642
|
+
|
|
514
643
|
ggml_backend_sched_synchronize(sched.get());
|
|
515
644
|
|
|
516
645
|
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
|
@@ -645,7 +774,7 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
|
|
645
774
|
float * llama_context::get_logits() {
|
|
646
775
|
output_reorder();
|
|
647
776
|
|
|
648
|
-
return logits;
|
|
777
|
+
return logits.data;
|
|
649
778
|
}
|
|
650
779
|
|
|
651
780
|
int64_t llama_context::output_resolve_row(int32_t i) const {
|
|
@@ -678,36 +807,15 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
|
|
|
678
807
|
}
|
|
679
808
|
|
|
680
809
|
float * llama_context::get_logits_ith(int32_t i) {
|
|
681
|
-
int64_t j = -1;
|
|
682
|
-
|
|
683
810
|
output_reorder();
|
|
684
811
|
|
|
685
812
|
try {
|
|
686
|
-
if (logits == nullptr) {
|
|
813
|
+
if (logits.data == nullptr) {
|
|
687
814
|
throw std::runtime_error("no logits");
|
|
688
815
|
}
|
|
689
816
|
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
j = n_outputs + i;
|
|
693
|
-
if (j < 0) {
|
|
694
|
-
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
695
|
-
}
|
|
696
|
-
} else if ((size_t) i >= output_ids.size()) {
|
|
697
|
-
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
698
|
-
} else {
|
|
699
|
-
j = output_ids[i];
|
|
700
|
-
}
|
|
701
|
-
|
|
702
|
-
if (j < 0) {
|
|
703
|
-
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
704
|
-
}
|
|
705
|
-
if (j >= n_outputs) {
|
|
706
|
-
// This should not happen
|
|
707
|
-
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
708
|
-
}
|
|
709
|
-
|
|
710
|
-
return logits + j*model.vocab.n_tokens();
|
|
817
|
+
const int64_t j = output_resolve_row(i);
|
|
818
|
+
return logits.data + j*model.vocab.n_tokens();
|
|
711
819
|
} catch (const std::exception & err) {
|
|
712
820
|
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
|
713
821
|
#ifndef NDEBUG
|
|
@@ -721,45 +829,24 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
721
829
|
float * llama_context::get_embeddings() {
|
|
722
830
|
output_reorder();
|
|
723
831
|
|
|
724
|
-
return embd;
|
|
832
|
+
return embd.data;
|
|
725
833
|
}
|
|
726
834
|
|
|
727
835
|
llama_token * llama_context::get_sampled_tokens() const{
|
|
728
|
-
return sampling.sampled;
|
|
836
|
+
return sampling.sampled.data;
|
|
729
837
|
}
|
|
730
838
|
|
|
731
839
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
|
732
|
-
int64_t j = -1;
|
|
733
|
-
|
|
734
840
|
output_reorder();
|
|
735
841
|
|
|
736
842
|
try {
|
|
737
|
-
if (embd == nullptr) {
|
|
843
|
+
if (embd.data == nullptr) {
|
|
738
844
|
throw std::runtime_error("no embeddings");
|
|
739
845
|
}
|
|
740
846
|
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
if (j < 0) {
|
|
745
|
-
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
|
746
|
-
}
|
|
747
|
-
} else if ((size_t) i >= output_ids.size()) {
|
|
748
|
-
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
|
749
|
-
} else {
|
|
750
|
-
j = output_ids[i];
|
|
751
|
-
}
|
|
752
|
-
|
|
753
|
-
if (j < 0) {
|
|
754
|
-
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
755
|
-
}
|
|
756
|
-
if (j >= n_outputs) {
|
|
757
|
-
// This should not happen
|
|
758
|
-
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
|
759
|
-
}
|
|
760
|
-
|
|
761
|
-
const uint32_t n_embd_out = model.hparams.get_n_embd_out();
|
|
762
|
-
return embd + j*n_embd_out;
|
|
847
|
+
const int64_t j = output_resolve_row(i);
|
|
848
|
+
const uint32_t n_embd_out = model.hparams.n_embd_out();
|
|
849
|
+
return embd.data + j*n_embd_out;
|
|
763
850
|
} catch (const std::exception & err) {
|
|
764
851
|
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
|
765
852
|
#ifndef NDEBUG
|
|
@@ -782,14 +869,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
782
869
|
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
|
783
870
|
output_reorder();
|
|
784
871
|
|
|
785
|
-
if (sampling.sampled
|
|
872
|
+
if (!sampling.sampled.has_data()) {
|
|
786
873
|
return LLAMA_TOKEN_NULL;
|
|
787
874
|
}
|
|
788
875
|
|
|
789
876
|
try {
|
|
790
877
|
const int64_t row = output_resolve_row(idx);
|
|
791
|
-
GGML_ASSERT(row < (int64_t) sampling.
|
|
792
|
-
return sampling.sampled[row];
|
|
878
|
+
GGML_ASSERT(row < (int64_t) sampling.sampled.size);
|
|
879
|
+
return sampling.sampled.data[row];
|
|
793
880
|
} catch (const std::exception & err) {
|
|
794
881
|
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
|
|
795
882
|
return LLAMA_TOKEN_NULL;
|
|
@@ -799,7 +886,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
|
|
799
886
|
float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
|
800
887
|
output_reorder();
|
|
801
888
|
|
|
802
|
-
if (sampling.probs
|
|
889
|
+
if (!sampling.probs.has_data()) {
|
|
803
890
|
return nullptr;
|
|
804
891
|
}
|
|
805
892
|
|
|
@@ -808,7 +895,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
|
|
808
895
|
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
|
|
809
896
|
return nullptr;
|
|
810
897
|
}
|
|
811
|
-
return sampling.probs + row*model.vocab.n_tokens();
|
|
898
|
+
return sampling.probs.data + row*model.vocab.n_tokens();
|
|
812
899
|
} catch (const std::exception & err) {
|
|
813
900
|
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
|
|
814
901
|
return nullptr;
|
|
@@ -818,7 +905,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
|
|
818
905
|
float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
|
819
906
|
output_reorder();
|
|
820
907
|
|
|
821
|
-
if (sampling.logits
|
|
908
|
+
if (!sampling.logits.has_data()) {
|
|
822
909
|
return nullptr;
|
|
823
910
|
}
|
|
824
911
|
|
|
@@ -827,7 +914,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
|
|
827
914
|
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
|
|
828
915
|
return nullptr;
|
|
829
916
|
}
|
|
830
|
-
return sampling.logits + row*model.vocab.n_tokens();
|
|
917
|
+
return sampling.logits.data + row*model.vocab.n_tokens();
|
|
831
918
|
} catch (const std::exception & err) {
|
|
832
919
|
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
|
|
833
920
|
return nullptr;
|
|
@@ -839,13 +926,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
|
|
839
926
|
|
|
840
927
|
try {
|
|
841
928
|
const int64_t row = output_resolve_row(idx);
|
|
842
|
-
if (sampling.candidates
|
|
929
|
+
if (sampling.candidates.has_data() &&
|
|
843
930
|
(size_t) row < sampling.candidates_count.size() &&
|
|
844
931
|
sampling.candidates_count[row] > 0) {
|
|
845
|
-
return sampling.candidates + row*model.vocab.n_tokens();
|
|
932
|
+
return sampling.candidates.data + row*model.vocab.n_tokens();
|
|
846
933
|
}
|
|
847
934
|
} catch (const std::exception & err) {
|
|
848
935
|
// fallback to full vocab list
|
|
936
|
+
GGML_UNUSED(err);
|
|
849
937
|
}
|
|
850
938
|
|
|
851
939
|
return sampling.token_ids_full_vocab.data();
|
|
@@ -854,7 +942,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
|
|
854
942
|
size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
|
855
943
|
output_reorder();
|
|
856
944
|
|
|
857
|
-
if (sampling.candidates
|
|
945
|
+
if (!sampling.candidates.has_data()) {
|
|
858
946
|
return 0;
|
|
859
947
|
}
|
|
860
948
|
|
|
@@ -873,7 +961,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
|
|
873
961
|
size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
|
874
962
|
output_reorder();
|
|
875
963
|
|
|
876
|
-
if (sampling.logits
|
|
964
|
+
if (!sampling.logits.has_data()) {
|
|
877
965
|
return model.vocab.n_tokens();
|
|
878
966
|
}
|
|
879
967
|
|
|
@@ -892,7 +980,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
|
|
892
980
|
size_t llama_context::get_sampled_probs_count(int32_t idx) {
|
|
893
981
|
output_reorder();
|
|
894
982
|
|
|
895
|
-
if (sampling.probs
|
|
983
|
+
if (!sampling.probs.has_data()) {
|
|
896
984
|
return 0;
|
|
897
985
|
}
|
|
898
986
|
|
|
@@ -951,21 +1039,41 @@ void llama_context::set_embeddings(bool value) {
|
|
|
951
1039
|
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
952
1040
|
|
|
953
1041
|
cparams.embeddings = value;
|
|
1042
|
+
|
|
1043
|
+
// TODO: not sure yet if we want to reserve here
|
|
1044
|
+
//sched_need_reserve = true;
|
|
954
1045
|
}
|
|
955
1046
|
|
|
956
1047
|
void llama_context::set_causal_attn(bool value) {
|
|
957
1048
|
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
958
1049
|
|
|
1050
|
+
if (cparams.causal_attn == value) {
|
|
1051
|
+
return;
|
|
1052
|
+
}
|
|
1053
|
+
|
|
959
1054
|
cparams.causal_attn = value;
|
|
1055
|
+
|
|
1056
|
+
sched_need_reserve = true;
|
|
960
1057
|
}
|
|
961
1058
|
|
|
962
1059
|
void llama_context::set_warmup(bool value) {
|
|
963
1060
|
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
|
964
1061
|
|
|
1062
|
+
if (cparams.warmup == value) {
|
|
1063
|
+
return;
|
|
1064
|
+
}
|
|
1065
|
+
|
|
965
1066
|
cparams.warmup = value;
|
|
1067
|
+
|
|
1068
|
+
// warmups are usually with small batches, so no need to reserve
|
|
1069
|
+
//sched_need_reserve = true;
|
|
966
1070
|
}
|
|
967
1071
|
|
|
968
1072
|
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|
1073
|
+
if (!sampler && sampling.samplers.count(seq_id) == 0) {
|
|
1074
|
+
return true;
|
|
1075
|
+
}
|
|
1076
|
+
|
|
969
1077
|
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
|
970
1078
|
|
|
971
1079
|
const bool can_offload =
|
|
@@ -975,22 +1083,24 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|
|
975
1083
|
llama_sampler_chain_n(sampler) > 0;
|
|
976
1084
|
|
|
977
1085
|
if (sampler && can_offload) {
|
|
978
|
-
|
|
979
|
-
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
|
|
980
|
-
if (host_buft) {
|
|
981
|
-
buft = host_buft;
|
|
982
|
-
}
|
|
1086
|
+
auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
|
|
983
1087
|
|
|
984
1088
|
sampler->iface->backend_init(sampler, buft);
|
|
985
1089
|
|
|
986
1090
|
sampling.samplers[seq_id] = sampler;
|
|
987
1091
|
|
|
1092
|
+
sched_need_reserve = true;
|
|
1093
|
+
|
|
988
1094
|
return true;
|
|
989
1095
|
}
|
|
990
1096
|
|
|
991
1097
|
if (sampler && !can_offload) {
|
|
992
1098
|
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
|
|
993
1099
|
|
|
1100
|
+
if (sampling.samplers.count(seq_id) > 0) {
|
|
1101
|
+
sched_need_reserve = true;
|
|
1102
|
+
}
|
|
1103
|
+
|
|
994
1104
|
sampling.samplers.erase(seq_id);
|
|
995
1105
|
|
|
996
1106
|
return false;
|
|
@@ -998,37 +1108,56 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|
|
998
1108
|
|
|
999
1109
|
sampling.samplers.erase(seq_id);
|
|
1000
1110
|
|
|
1111
|
+
sched_need_reserve = true;
|
|
1112
|
+
|
|
1001
1113
|
return true;
|
|
1002
1114
|
}
|
|
1003
1115
|
|
|
1004
|
-
void llama_context::
|
|
1005
|
-
|
|
1006
|
-
float scale) {
|
|
1007
|
-
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
|
|
1116
|
+
void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
|
1117
|
+
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
|
1008
1118
|
|
|
1009
|
-
|
|
1010
|
-
|
|
1119
|
+
if (adapters_lora_are_same(adapters, n_adapters, scales)) {
|
|
1120
|
+
return;
|
|
1121
|
+
}
|
|
1011
1122
|
|
|
1012
|
-
|
|
1013
|
-
llama_adapter_lora * adapter) {
|
|
1014
|
-
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
|
|
1123
|
+
loras.reset(new llama_adapter_loras());
|
|
1015
1124
|
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1125
|
+
for (size_t i = 0; i < n_adapters; i ++) {
|
|
1126
|
+
if (scales[i] != 0.0f) {
|
|
1127
|
+
loras->insert({adapters[i], scales[i]});
|
|
1128
|
+
}
|
|
1020
1129
|
}
|
|
1021
1130
|
|
|
1022
|
-
|
|
1131
|
+
sched_need_reserve = true;
|
|
1023
1132
|
}
|
|
1024
1133
|
|
|
1025
|
-
|
|
1026
|
-
LLAMA_LOG_DEBUG("%s:
|
|
1134
|
+
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
|
1135
|
+
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
|
1136
|
+
|
|
1137
|
+
// Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison.
|
|
1138
|
+
size_t n_non_zero = 0;
|
|
1139
|
+
|
|
1140
|
+
for (size_t i = 0; i < n_adapters; i ++) {
|
|
1141
|
+
if (scales[i] == 0.0f) {
|
|
1142
|
+
continue;
|
|
1143
|
+
}
|
|
1144
|
+
n_non_zero++;
|
|
1145
|
+
|
|
1146
|
+
auto it = loras->find(adapters[i]);
|
|
1027
1147
|
|
|
1028
|
-
|
|
1148
|
+
if (it == loras->end() || it->second != scales[i]) {
|
|
1149
|
+
return false;
|
|
1150
|
+
}
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
if (n_non_zero != loras->size()) {
|
|
1154
|
+
return false;
|
|
1155
|
+
}
|
|
1156
|
+
|
|
1157
|
+
return true;
|
|
1029
1158
|
}
|
|
1030
1159
|
|
|
1031
|
-
bool llama_context::
|
|
1160
|
+
bool llama_context::set_adapter_cvec(
|
|
1032
1161
|
const float * data,
|
|
1033
1162
|
size_t len,
|
|
1034
1163
|
int32_t n_embd,
|
|
@@ -1036,7 +1165,9 @@ bool llama_context::apply_adapter_cvec(
|
|
|
1036
1165
|
int32_t il_end) {
|
|
1037
1166
|
LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
|
|
1038
1167
|
|
|
1039
|
-
|
|
1168
|
+
// TODO: should we reserve?
|
|
1169
|
+
|
|
1170
|
+
return cvec->apply(model, data, len, n_embd, il_start, il_end);
|
|
1040
1171
|
}
|
|
1041
1172
|
|
|
1042
1173
|
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
|
@@ -1086,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
|
|
1086
1217
|
{
|
|
1087
1218
|
//const auto t_start_us = ggml_time_us();
|
|
1088
1219
|
|
|
1220
|
+
// FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated
|
|
1089
1221
|
res->set_inputs(&ubatch);
|
|
1090
1222
|
|
|
1091
1223
|
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
|
|
@@ -1138,10 +1270,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
1138
1270
|
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
1139
1271
|
embd_seq.clear();
|
|
1140
1272
|
|
|
1273
|
+
sched_reserve();
|
|
1274
|
+
|
|
1141
1275
|
n_queued_tokens += n_tokens;
|
|
1142
1276
|
|
|
1143
1277
|
// reserve output buffer
|
|
1144
|
-
if (output_reserve(n_tokens
|
|
1278
|
+
if (output_reserve(n_tokens) < n_tokens) {
|
|
1145
1279
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
|
1146
1280
|
return -2;
|
|
1147
1281
|
};
|
|
@@ -1177,16 +1311,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
1177
1311
|
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
|
1178
1312
|
|
|
1179
1313
|
// extract logits
|
|
1180
|
-
|
|
1314
|
+
if (logits.data && t_logits) {
|
|
1181
1315
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
1182
1316
|
GGML_ASSERT(backend_res != nullptr);
|
|
1183
|
-
GGML_ASSERT(logits != nullptr);
|
|
1317
|
+
GGML_ASSERT(logits.data != nullptr);
|
|
1184
1318
|
|
|
1185
|
-
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
|
|
1319
|
+
ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
|
|
1186
1320
|
}
|
|
1187
1321
|
|
|
1188
1322
|
// extract embeddings
|
|
1189
|
-
if (embd && t_embd) {
|
|
1323
|
+
if (embd.data && t_embd) {
|
|
1190
1324
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1191
1325
|
GGML_ASSERT(backend_embd != nullptr);
|
|
1192
1326
|
|
|
@@ -1194,11 +1328,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
1194
1328
|
case LLAMA_POOLING_TYPE_NONE:
|
|
1195
1329
|
{
|
|
1196
1330
|
// extract token embeddings
|
|
1197
|
-
GGML_ASSERT(embd != nullptr);
|
|
1198
|
-
const uint32_t n_embd_out = hparams.
|
|
1331
|
+
GGML_ASSERT(embd.data != nullptr);
|
|
1332
|
+
const uint32_t n_embd_out = hparams.n_embd_out();
|
|
1199
1333
|
|
|
1200
|
-
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t)
|
|
1201
|
-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
|
|
1334
|
+
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
|
|
1335
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
|
|
1202
1336
|
} break;
|
|
1203
1337
|
case LLAMA_POOLING_TYPE_MEAN:
|
|
1204
1338
|
case LLAMA_POOLING_TYPE_CLS:
|
|
@@ -1246,7 +1380,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
|
1246
1380
|
cross.n_embd = t_embd->ne[0];
|
|
1247
1381
|
cross.n_enc = t_embd->ne[1];
|
|
1248
1382
|
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
|
1249
|
-
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
|
1383
|
+
memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
|
|
1250
1384
|
|
|
1251
1385
|
const auto & batch = balloc->get_batch();
|
|
1252
1386
|
|
|
@@ -1286,11 +1420,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
|
|
|
1286
1420
|
|
|
1287
1421
|
static void copy_tensor_async_ints(
|
|
1288
1422
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1289
|
-
llama_token
|
|
1290
|
-
size_t sampled_size,
|
|
1423
|
+
const buffer_view<llama_token> & sampled,
|
|
1291
1424
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1292
1425
|
ggml_backend_sched_t sched) {
|
|
1293
|
-
if (sampled
|
|
1426
|
+
if (!sampled.has_data()) {
|
|
1294
1427
|
return;
|
|
1295
1428
|
}
|
|
1296
1429
|
|
|
@@ -1301,23 +1434,23 @@ static void copy_tensor_async_ints(
|
|
|
1301
1434
|
}
|
|
1302
1435
|
|
|
1303
1436
|
const uint32_t row = it->second;
|
|
1304
|
-
GGML_ASSERT(row <
|
|
1437
|
+
GGML_ASSERT(row < sampled.size);
|
|
1305
1438
|
|
|
1306
1439
|
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
|
1307
1440
|
|
|
1308
1441
|
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1309
|
-
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
|
|
1442
|
+
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
|
|
1310
1443
|
}
|
|
1311
1444
|
}
|
|
1312
1445
|
|
|
1313
1446
|
static void copy_tensor_async_floats(
|
|
1314
1447
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1315
|
-
float
|
|
1448
|
+
const buffer_view<float> & dst,
|
|
1316
1449
|
size_t stride,
|
|
1317
1450
|
std::vector<uint32_t> & counts,
|
|
1318
1451
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1319
1452
|
ggml_backend_sched_t sched) {
|
|
1320
|
-
if (dst
|
|
1453
|
+
if (!dst.has_data()) {
|
|
1321
1454
|
return;
|
|
1322
1455
|
}
|
|
1323
1456
|
|
|
@@ -1333,7 +1466,7 @@ static void copy_tensor_async_floats(
|
|
|
1333
1466
|
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
|
1334
1467
|
|
|
1335
1468
|
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1336
|
-
float * row_ptr = dst + (size_t) row * stride;
|
|
1469
|
+
float * row_ptr = dst.data + (size_t) row * stride;
|
|
1337
1470
|
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1338
1471
|
|
|
1339
1472
|
// Update the actual number of logits/probabilities that were written for this row.
|
|
@@ -1343,12 +1476,12 @@ static void copy_tensor_async_floats(
|
|
|
1343
1476
|
|
|
1344
1477
|
static void copy_tensor_async_candidates(
|
|
1345
1478
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
|
1346
|
-
llama_token
|
|
1479
|
+
const buffer_view<llama_token> & dst,
|
|
1347
1480
|
size_t stride,
|
|
1348
1481
|
std::vector<uint32_t> & counts,
|
|
1349
1482
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
|
1350
1483
|
ggml_backend_sched_t sched) {
|
|
1351
|
-
if (dst
|
|
1484
|
+
if (!dst.has_data()) {
|
|
1352
1485
|
return;
|
|
1353
1486
|
}
|
|
1354
1487
|
|
|
@@ -1364,7 +1497,7 @@ static void copy_tensor_async_candidates(
|
|
|
1364
1497
|
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
|
1365
1498
|
|
|
1366
1499
|
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
|
1367
|
-
llama_token * row_ptr = dst + (size_t) row * stride;
|
|
1500
|
+
llama_token * row_ptr = dst.data + (size_t) row * stride;
|
|
1368
1501
|
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
|
1369
1502
|
|
|
1370
1503
|
// Update the actual number of candidates that were written.
|
|
@@ -1372,6 +1505,23 @@ static void copy_tensor_async_candidates(
|
|
|
1372
1505
|
}
|
|
1373
1506
|
}
|
|
1374
1507
|
|
|
1508
|
+
static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
|
|
1509
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
|
1510
|
+
if (!ubatch.output[i]) {
|
|
1511
|
+
continue;
|
|
1512
|
+
}
|
|
1513
|
+
|
|
1514
|
+
// Check if the output token has at least one sequence without a backend sampler.
|
|
1515
|
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
|
1516
|
+
llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
1517
|
+
if (samplers.find(seq_id) == samplers.end()) {
|
|
1518
|
+
return true;
|
|
1519
|
+
}
|
|
1520
|
+
}
|
|
1521
|
+
}
|
|
1522
|
+
return false; // all sequences use backend sampling
|
|
1523
|
+
}
|
|
1524
|
+
|
|
1375
1525
|
int llama_context::decode(const llama_batch & batch_inp) {
|
|
1376
1526
|
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
|
1377
1527
|
|
|
@@ -1451,6 +1601,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1451
1601
|
embd_seq.clear();
|
|
1452
1602
|
output_swaps.clear();
|
|
1453
1603
|
|
|
1604
|
+
sched_reserve();
|
|
1605
|
+
|
|
1454
1606
|
bool did_optimize = false;
|
|
1455
1607
|
|
|
1456
1608
|
// handle any pending shifts/copies
|
|
@@ -1502,7 +1654,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1502
1654
|
}
|
|
1503
1655
|
|
|
1504
1656
|
// reserve output buffer
|
|
1505
|
-
if (output_reserve(n_outputs_all
|
|
1657
|
+
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
1506
1658
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
1507
1659
|
return -2;
|
|
1508
1660
|
};
|
|
@@ -1575,25 +1727,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1575
1727
|
}
|
|
1576
1728
|
|
|
1577
1729
|
// extract logits
|
|
1578
|
-
|
|
1579
|
-
// this is currently inefficient as we copy all logits even for the
|
|
1580
|
-
// backend sampled tokens.
|
|
1581
|
-
if (logits && t_logits && n_outputs > 0) {
|
|
1730
|
+
if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
|
|
1582
1731
|
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
|
1583
1732
|
GGML_ASSERT(backend_res != nullptr);
|
|
1584
|
-
GGML_ASSERT(logits != nullptr);
|
|
1733
|
+
GGML_ASSERT(logits.data != nullptr);
|
|
1585
1734
|
|
|
1586
|
-
float * logits_out = logits + n_outputs_prev*n_vocab;
|
|
1735
|
+
float * logits_out = logits.data + n_outputs_prev*n_vocab;
|
|
1587
1736
|
|
|
1588
1737
|
if (n_outputs) {
|
|
1589
1738
|
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1590
|
-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t)
|
|
1739
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
|
|
1591
1740
|
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
|
1592
1741
|
}
|
|
1593
1742
|
}
|
|
1594
1743
|
|
|
1595
1744
|
// extract embeddings
|
|
1596
|
-
if (embd && t_embd && n_outputs > 0) {
|
|
1745
|
+
if (embd.data && t_embd && n_outputs > 0) {
|
|
1597
1746
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1598
1747
|
GGML_ASSERT(backend_embd != nullptr);
|
|
1599
1748
|
|
|
@@ -1601,13 +1750,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1601
1750
|
case LLAMA_POOLING_TYPE_NONE:
|
|
1602
1751
|
{
|
|
1603
1752
|
// extract token embeddings
|
|
1604
|
-
GGML_ASSERT(embd != nullptr);
|
|
1605
|
-
const uint32_t n_embd_out = hparams.
|
|
1606
|
-
float * embd_out = embd + n_outputs_prev*n_embd_out;
|
|
1753
|
+
GGML_ASSERT(embd.data != nullptr);
|
|
1754
|
+
const uint32_t n_embd_out = hparams.n_embd_out();
|
|
1755
|
+
float * embd_out = embd.data + n_outputs_prev*n_embd_out;
|
|
1607
1756
|
|
|
1608
1757
|
if (n_outputs) {
|
|
1609
1758
|
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
|
1610
|
-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t)
|
|
1759
|
+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
|
|
1611
1760
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
|
|
1612
1761
|
}
|
|
1613
1762
|
} break;
|
|
@@ -1648,16 +1797,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1648
1797
|
}
|
|
1649
1798
|
}
|
|
1650
1799
|
|
|
1651
|
-
//
|
|
1652
|
-
|
|
1653
|
-
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
|
1654
|
-
|
|
1655
|
-
if (has_samplers && has_sampled) {
|
|
1800
|
+
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
|
1801
|
+
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
|
|
1656
1802
|
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
|
1657
1803
|
const auto stride = n_vocab;
|
|
1658
1804
|
|
|
1659
1805
|
// async copy the sampling data from the backend to the host
|
|
1660
|
-
copy_tensor_async_ints(res->t_sampled, sampling.sampled,
|
|
1806
|
+
copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
|
|
1661
1807
|
|
|
1662
1808
|
copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
|
|
1663
1809
|
copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
|
|
@@ -1727,7 +1873,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
|
1727
1873
|
// output
|
|
1728
1874
|
//
|
|
1729
1875
|
|
|
1730
|
-
uint32_t llama_context::output_reserve(int32_t n_outputs
|
|
1876
|
+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
1731
1877
|
const auto & hparams = model.hparams;
|
|
1732
1878
|
const auto & vocab = model.vocab;
|
|
1733
1879
|
|
|
@@ -1735,7 +1881,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1735
1881
|
|
|
1736
1882
|
const auto n_batch = cparams.n_batch;
|
|
1737
1883
|
const auto n_vocab = vocab.n_tokens();
|
|
1738
|
-
const auto n_embd_out = hparams.
|
|
1884
|
+
const auto n_embd_out = hparams.n_embd_out();
|
|
1739
1885
|
|
|
1740
1886
|
bool has_logits = true;
|
|
1741
1887
|
bool has_embd = cparams.embeddings;
|
|
@@ -1746,52 +1892,18 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1746
1892
|
has_embd = true;
|
|
1747
1893
|
}
|
|
1748
1894
|
|
|
1749
|
-
// Check which sampling modes are needed for the current batch.
|
|
1750
|
-
// TODO: avoid this branching by working with the worst-case
|
|
1751
|
-
bool has_sampling = false;
|
|
1752
|
-
bool cpu_logits = false;
|
|
1753
|
-
|
|
1754
|
-
if (batch.logits) {
|
|
1755
|
-
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
1756
|
-
if (!batch.logits[i]) {
|
|
1757
|
-
continue;
|
|
1758
|
-
}
|
|
1759
|
-
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
|
1760
|
-
llama_seq_id seq_id = batch.seq_id[i][j];
|
|
1761
|
-
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
|
|
1762
|
-
has_sampling = true;
|
|
1763
|
-
} else {
|
|
1764
|
-
cpu_logits = true;
|
|
1765
|
-
}
|
|
1766
|
-
}
|
|
1767
|
-
}
|
|
1768
|
-
} else {
|
|
1769
|
-
// When batch.logits is nullptr (when loading state with a dummy batch),
|
|
1770
|
-
// allocate CPU logits.
|
|
1771
|
-
cpu_logits = true;
|
|
1772
|
-
}
|
|
1773
1895
|
|
|
1774
1896
|
size_t backend_float_count = 0;
|
|
1775
1897
|
size_t backend_token_count = 0;
|
|
1776
1898
|
|
|
1777
|
-
|
|
1778
|
-
|
|
1779
|
-
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
|
|
1780
|
-
|
|
1781
|
-
// TODO: avoid this branching by working with the worst-case
|
|
1782
|
-
if (!has_sampling) {
|
|
1783
|
-
sampling.logits_size = 0;
|
|
1784
|
-
sampling.probs_size = 0;
|
|
1785
|
-
sampling.sampled_size = 0;
|
|
1786
|
-
sampling.candidates_size = 0;
|
|
1787
|
-
} else {
|
|
1788
|
-
sampling.logits_size = n_vocab*n_outputs_max;
|
|
1789
|
-
sampling.probs_size = n_vocab*n_outputs_max;
|
|
1790
|
-
sampling.sampled_size = n_outputs_max;
|
|
1791
|
-
sampling.candidates_size = n_vocab*n_outputs_max;
|
|
1899
|
+
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
|
1900
|
+
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
|
1792
1901
|
|
|
1793
|
-
|
|
1794
|
-
|
|
1902
|
+
// Allocate backend sampling output buffers if there are backend samplers configured.
|
|
1903
|
+
const bool has_sampling = !sampling.samplers.empty();
|
|
1904
|
+
if (has_sampling) {
|
|
1905
|
+
backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
|
|
1906
|
+
backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
|
|
1795
1907
|
}
|
|
1796
1908
|
|
|
1797
1909
|
if (output_ids.empty()) {
|
|
@@ -1801,7 +1913,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1801
1913
|
|
|
1802
1914
|
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
|
1803
1915
|
const size_t new_size =
|
|
1804
|
-
(
|
|
1916
|
+
(logits.size + embd.size + backend_float_count) * sizeof(float) +
|
|
1805
1917
|
( backend_token_count) * sizeof(llama_token);
|
|
1806
1918
|
|
|
1807
1919
|
// alloc only when more than the current capacity is required
|
|
@@ -1816,8 +1928,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1816
1928
|
|
|
1817
1929
|
// TODO: not needed?
|
|
1818
1930
|
buf_output = nullptr;
|
|
1819
|
-
logits = nullptr;
|
|
1820
|
-
embd = nullptr;
|
|
1931
|
+
logits.data = nullptr;
|
|
1932
|
+
embd.data = nullptr;
|
|
1821
1933
|
}
|
|
1822
1934
|
|
|
1823
1935
|
auto * buft = ggml_backend_cpu_buffer_type();
|
|
@@ -1836,35 +1948,27 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1836
1948
|
|
|
1837
1949
|
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
|
1838
1950
|
|
|
1839
|
-
logits = nullptr;
|
|
1840
|
-
embd = nullptr;
|
|
1841
|
-
|
|
1842
1951
|
size_t offset = 0;
|
|
1843
1952
|
uint8_t * base = (uint8_t *) output_base;
|
|
1844
1953
|
|
|
1845
|
-
logits =
|
|
1846
|
-
offset +=
|
|
1954
|
+
logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
|
|
1955
|
+
offset += logits.size * sizeof(float);
|
|
1847
1956
|
|
|
1848
|
-
embd = has_embd ? (float *) (base + offset) : nullptr;
|
|
1849
|
-
offset +=
|
|
1850
|
-
|
|
1851
|
-
sampling.logits = nullptr;
|
|
1852
|
-
sampling.probs = nullptr;
|
|
1853
|
-
sampling.sampled = nullptr;
|
|
1854
|
-
sampling.candidates = nullptr;
|
|
1957
|
+
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
|
1958
|
+
offset += embd.size * sizeof(float);
|
|
1855
1959
|
|
|
1856
1960
|
if (has_sampling) {
|
|
1857
|
-
sampling.logits = (float *) (base + offset);
|
|
1858
|
-
offset += sampling.
|
|
1961
|
+
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1962
|
+
offset += sampling.logits.size * sizeof(float);
|
|
1859
1963
|
|
|
1860
|
-
sampling.probs = (float *) (base + offset);
|
|
1861
|
-
offset += sampling.
|
|
1964
|
+
sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1965
|
+
offset += sampling.probs.size * sizeof(float);
|
|
1862
1966
|
|
|
1863
|
-
sampling.sampled = (llama_token *) (base + offset);
|
|
1864
|
-
offset += sampling.
|
|
1967
|
+
sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
|
|
1968
|
+
offset += sampling.sampled.size * sizeof(llama_token);
|
|
1865
1969
|
|
|
1866
|
-
sampling.candidates = (llama_token *) (base + offset);
|
|
1867
|
-
offset += sampling.
|
|
1970
|
+
sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
|
1971
|
+
offset += sampling.candidates.size * sizeof(llama_token);
|
|
1868
1972
|
|
|
1869
1973
|
// The count vectors keep track of the actual number of logits/probs/candidates
|
|
1870
1974
|
// copied from the backend for each output row.
|
|
@@ -1877,7 +1981,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|
|
1877
1981
|
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
|
1878
1982
|
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
|
1879
1983
|
|
|
1880
|
-
std::fill_n(sampling.sampled, sampling.
|
|
1984
|
+
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
|
1985
|
+
} else {
|
|
1986
|
+
sampling.logits = {nullptr, 0};
|
|
1987
|
+
sampling.probs = {nullptr, 0};
|
|
1988
|
+
sampling.sampled = {nullptr, 0};
|
|
1989
|
+
sampling.candidates = {nullptr, 0};
|
|
1990
|
+
|
|
1991
|
+
sampling.logits_count.clear();
|
|
1992
|
+
sampling.probs_count.clear();
|
|
1993
|
+
sampling.candidates_count.clear();
|
|
1881
1994
|
}
|
|
1882
1995
|
|
|
1883
1996
|
// set all ids as invalid (negative)
|
|
@@ -1896,49 +2009,42 @@ void llama_context::output_reorder() {
|
|
|
1896
2009
|
const uint64_t i0 = output_swaps[s].i0;
|
|
1897
2010
|
const uint64_t i1 = output_swaps[s].i1;
|
|
1898
2011
|
|
|
1899
|
-
if (
|
|
2012
|
+
if (logits.size > 0) {
|
|
1900
2013
|
for (uint64_t k = 0; k < n_vocab; k++) {
|
|
1901
|
-
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
|
|
2014
|
+
std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
|
|
1902
2015
|
}
|
|
1903
2016
|
}
|
|
1904
2017
|
|
|
1905
|
-
if (
|
|
2018
|
+
if (embd.size > 0) {
|
|
1906
2019
|
for (uint64_t k = 0; k < n_embd; k++) {
|
|
1907
|
-
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
|
2020
|
+
std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
|
|
1908
2021
|
}
|
|
1909
2022
|
}
|
|
1910
2023
|
|
|
1911
|
-
if (sampling.
|
|
2024
|
+
if (!sampling.samplers.empty()) {
|
|
2025
|
+
assert(sampling.logits.size > 0);
|
|
2026
|
+
assert(sampling.probs.size > 0);
|
|
2027
|
+
assert(sampling.candidates.size > 0);
|
|
2028
|
+
assert(sampling.sampled.size > 0);
|
|
2029
|
+
assert(sampling.logits_count.size() > 0);
|
|
2030
|
+
assert(sampling.probs_count.size() > 0);
|
|
2031
|
+
assert(sampling.candidates_count.size() > 0);
|
|
2032
|
+
|
|
1912
2033
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1913
|
-
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
|
|
2034
|
+
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
|
1914
2035
|
}
|
|
1915
|
-
}
|
|
1916
2036
|
|
|
1917
|
-
if (sampling.probs && sampling.probs_size > 0) {
|
|
1918
2037
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1919
|
-
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
|
|
2038
|
+
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
|
1920
2039
|
}
|
|
1921
|
-
}
|
|
1922
2040
|
|
|
1923
|
-
if (sampling.candidates && sampling.candidates_size > 0) {
|
|
1924
2041
|
for (uint64_t k = 0; k < n_vocab; ++k) {
|
|
1925
|
-
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
|
|
2042
|
+
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
|
1926
2043
|
}
|
|
1927
|
-
}
|
|
1928
2044
|
|
|
1929
|
-
|
|
1930
|
-
std::swap(sampling.
|
|
1931
|
-
|
|
1932
|
-
|
|
1933
|
-
if (!sampling.logits_count.empty()) {
|
|
1934
|
-
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
|
1935
|
-
}
|
|
1936
|
-
|
|
1937
|
-
if (!sampling.probs_count.empty()) {
|
|
1938
|
-
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
|
1939
|
-
}
|
|
1940
|
-
|
|
1941
|
-
if (!sampling.candidates_count.empty()) {
|
|
2045
|
+
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
|
2046
|
+
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
|
2047
|
+
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
|
1942
2048
|
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
|
1943
2049
|
}
|
|
1944
2050
|
}
|
|
@@ -1951,11 +2057,13 @@ void llama_context::output_reorder() {
|
|
|
1951
2057
|
//
|
|
1952
2058
|
|
|
1953
2059
|
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
|
1954
|
-
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
|
2060
|
+
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
|
|
1955
2061
|
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
|
1956
2062
|
}
|
|
1957
2063
|
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
|
1958
|
-
|
|
2064
|
+
for (const auto & lora : model.loras) {
|
|
2065
|
+
res += lora->get_n_nodes();
|
|
2066
|
+
}
|
|
1959
2067
|
return res;
|
|
1960
2068
|
}
|
|
1961
2069
|
|
|
@@ -1977,7 +2085,7 @@ ggml_cgraph * llama_context::graph_reserve(
|
|
|
1977
2085
|
|
|
1978
2086
|
ggml_backend_sched_reset(sched.get());
|
|
1979
2087
|
|
|
1980
|
-
// when the scheduler is reset, we
|
|
2088
|
+
// when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
|
|
1981
2089
|
gf_res_prev->reset();
|
|
1982
2090
|
|
|
1983
2091
|
// store the n_outputs as it is, and restore it afterwards
|
|
@@ -2037,8 +2145,8 @@ llm_graph_params llama_context::graph_params(
|
|
|
2037
2145
|
/*.gtype =*/ gtype,
|
|
2038
2146
|
/*.sched =*/ sched.get(),
|
|
2039
2147
|
/*.backend_cpu =*/ backend_cpu,
|
|
2040
|
-
/*.cvec =*/
|
|
2041
|
-
/*.loras =*/
|
|
2148
|
+
/*.cvec =*/ cvec.get(),
|
|
2149
|
+
/*.loras =*/ loras.get(),
|
|
2042
2150
|
/*.mctx =*/ mctx,
|
|
2043
2151
|
/*.cross =*/ &cross,
|
|
2044
2152
|
/*.samplers =*/ sampling.samplers,
|
|
@@ -2085,13 +2193,6 @@ llm_graph_cb llama_context::graph_get_cb() const {
|
|
|
2085
2193
|
ggml_set_name(cur, name);
|
|
2086
2194
|
}
|
|
2087
2195
|
|
|
2088
|
-
if (!cparams.offload_kqv) {
|
|
2089
|
-
if (strcmp(name, "kqv_merged_cont") == 0) {
|
|
2090
|
-
// all nodes between the KV store and the attention output are run on the CPU
|
|
2091
|
-
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
|
|
2092
|
-
}
|
|
2093
|
-
}
|
|
2094
|
-
|
|
2095
2196
|
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
|
|
2096
2197
|
// FIXME: fix in ggml_backend_sched
|
|
2097
2198
|
const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
|
|
@@ -2443,63 +2544,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
|
2443
2544
|
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
|
2444
2545
|
}
|
|
2445
2546
|
|
|
2446
|
-
// write output ids
|
|
2447
|
-
{
|
|
2448
|
-
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
|
2449
|
-
|
|
2450
|
-
const auto n_outputs = this->n_outputs;
|
|
2451
|
-
const auto & output_ids = this->output_ids;
|
|
2452
|
-
|
|
2453
|
-
std::vector<int32_t> w_output_pos;
|
|
2454
|
-
|
|
2455
|
-
w_output_pos.resize(n_outputs);
|
|
2456
|
-
|
|
2457
|
-
// build a more compact representation of the output ids
|
|
2458
|
-
for (size_t i = 0; i < n_batch(); ++i) {
|
|
2459
|
-
// map an output id to a position in the batch
|
|
2460
|
-
int64_t pos = output_ids[i];
|
|
2461
|
-
if (pos >= 0) {
|
|
2462
|
-
GGML_ASSERT(pos < n_outputs);
|
|
2463
|
-
w_output_pos[pos] = i;
|
|
2464
|
-
}
|
|
2465
|
-
}
|
|
2466
|
-
|
|
2467
|
-
io.write(&n_outputs, sizeof(n_outputs));
|
|
2468
|
-
|
|
2469
|
-
if (n_outputs) {
|
|
2470
|
-
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
|
|
2471
|
-
}
|
|
2472
|
-
}
|
|
2473
|
-
|
|
2474
|
-
// write logits
|
|
2475
|
-
{
|
|
2476
|
-
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
|
2477
|
-
|
|
2478
|
-
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
|
2479
|
-
|
|
2480
|
-
io.write(&logits_size, sizeof(logits_size));
|
|
2481
|
-
|
|
2482
|
-
if (logits_size) {
|
|
2483
|
-
io.write(logits, logits_size * sizeof(float));
|
|
2484
|
-
}
|
|
2485
|
-
}
|
|
2486
|
-
|
|
2487
|
-
// write embeddings
|
|
2488
|
-
{
|
|
2489
|
-
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
|
2490
|
-
|
|
2491
|
-
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
|
|
2492
|
-
|
|
2493
|
-
io.write(&embd_size, sizeof(embd_size));
|
|
2494
|
-
|
|
2495
|
-
if (embd_size) {
|
|
2496
|
-
io.write(embd, embd_size * sizeof(float));
|
|
2497
|
-
}
|
|
2498
|
-
}
|
|
2499
|
-
|
|
2500
|
-
// TODO: handle sampling buffers and samplers state ?
|
|
2501
|
-
// https://github.com/ggml-org/llama.cpp/pull/17004
|
|
2502
|
-
|
|
2503
2547
|
if (memory != nullptr) {
|
|
2504
2548
|
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
|
2505
2549
|
memory->state_write(io);
|
|
@@ -2525,73 +2569,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
|
2525
2569
|
// TODO: add more info which needs to be identical but which is not verified otherwise
|
|
2526
2570
|
}
|
|
2527
2571
|
|
|
2528
|
-
// read output ids
|
|
2529
|
-
{
|
|
2530
|
-
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
|
|
2531
|
-
|
|
2532
|
-
auto n_outputs = this->n_outputs;
|
|
2533
|
-
io.read_to(&n_outputs, sizeof(n_outputs));
|
|
2534
|
-
|
|
2535
|
-
// Create a dummy batch for state loading.
|
|
2536
|
-
llama_batch dummy_batch = {};
|
|
2537
|
-
dummy_batch.n_tokens = 0;
|
|
2538
|
-
if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
|
|
2539
|
-
throw std::runtime_error("could not reserve outputs");
|
|
2540
|
-
}
|
|
2541
|
-
|
|
2542
|
-
std::vector<int32_t> output_pos;
|
|
2543
|
-
|
|
2544
|
-
if (n_outputs) {
|
|
2545
|
-
output_pos.resize(n_outputs);
|
|
2546
|
-
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
|
2547
|
-
|
|
2548
|
-
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
|
2549
|
-
int32_t id = output_pos[i];
|
|
2550
|
-
if ((uint32_t) id >= n_batch()) {
|
|
2551
|
-
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
|
|
2552
|
-
}
|
|
2553
|
-
this->output_ids[id] = i;
|
|
2554
|
-
}
|
|
2555
|
-
|
|
2556
|
-
this->n_outputs = n_outputs;
|
|
2557
|
-
}
|
|
2558
|
-
}
|
|
2559
|
-
|
|
2560
|
-
// read logits
|
|
2561
|
-
{
|
|
2562
|
-
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
|
|
2563
|
-
|
|
2564
|
-
uint64_t logits_size;
|
|
2565
|
-
io.read_to(&logits_size, sizeof(logits_size));
|
|
2566
|
-
|
|
2567
|
-
if (this->logits_size < logits_size) {
|
|
2568
|
-
throw std::runtime_error("logits buffer too small");
|
|
2569
|
-
}
|
|
2570
|
-
|
|
2571
|
-
if (logits_size) {
|
|
2572
|
-
io.read_to(this->logits, logits_size * sizeof(float));
|
|
2573
|
-
}
|
|
2574
|
-
}
|
|
2575
|
-
|
|
2576
|
-
// read embeddings
|
|
2577
|
-
{
|
|
2578
|
-
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
|
|
2579
|
-
|
|
2580
|
-
uint64_t embd_size;
|
|
2581
|
-
io.read_to(&embd_size, sizeof(embd_size));
|
|
2582
|
-
|
|
2583
|
-
if (this->embd_size < embd_size) {
|
|
2584
|
-
throw std::runtime_error("embeddings buffer too small");
|
|
2585
|
-
}
|
|
2586
|
-
|
|
2587
|
-
if (embd_size) {
|
|
2588
|
-
io.read_to(this->embd, embd_size * sizeof(float));
|
|
2589
|
-
}
|
|
2590
|
-
}
|
|
2591
|
-
|
|
2592
|
-
// TODO: handle sampling buffers and samplers state ?
|
|
2593
|
-
// https://github.com/ggml-org/llama.cpp/pull/17004
|
|
2594
|
-
|
|
2595
2572
|
if (memory) {
|
|
2596
2573
|
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
|
2597
2574
|
|
|
@@ -2724,6 +2701,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
|
|
|
2724
2701
|
llama_set_param(model->cls_b, param_filter, param_filter_ud);
|
|
2725
2702
|
llama_set_param(model->cls_out, param_filter, param_filter_ud);
|
|
2726
2703
|
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
|
|
2704
|
+
llama_set_param(model->cls_norm, param_filter, param_filter_ud);
|
|
2727
2705
|
|
|
2728
2706
|
for (struct llama_layer & layer : model->layers) {
|
|
2729
2707
|
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
|
@@ -2780,7 +2758,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2780
2758
|
}
|
|
2781
2759
|
|
|
2782
2760
|
// reserve output buffer
|
|
2783
|
-
if (output_reserve(n_outputs_all
|
|
2761
|
+
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
2784
2762
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
|
2785
2763
|
GGML_ABORT("TODO: handle this error");
|
|
2786
2764
|
};
|
|
@@ -2815,7 +2793,7 @@ void llama_context::opt_epoch_iter(
|
|
|
2815
2793
|
};
|
|
2816
2794
|
ctx_compute_opt = ggml_init(params);
|
|
2817
2795
|
}
|
|
2818
|
-
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->
|
|
2796
|
+
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
|
|
2819
2797
|
ggml_opt_alloc(opt_ctx, train);
|
|
2820
2798
|
|
|
2821
2799
|
res->set_inputs(&ubatch);
|
|
@@ -2957,19 +2935,23 @@ llama_context * llama_init_from_model(
|
|
|
2957
2935
|
|
|
2958
2936
|
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
|
2959
2937
|
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
|
2960
|
-
|
|
2961
|
-
|
|
2962
|
-
|
|
2963
|
-
|
|
2938
|
+
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
|
2939
|
+
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
|
2940
|
+
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
|
2941
|
+
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
|
|
2942
|
+
return nullptr;
|
|
2943
|
+
}
|
|
2964
2944
|
}
|
|
2965
2945
|
}
|
|
2966
2946
|
|
|
2967
2947
|
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
|
2968
2948
|
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
|
2969
|
-
|
|
2970
|
-
|
|
2971
|
-
|
|
2972
|
-
|
|
2949
|
+
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
|
2950
|
+
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
|
2951
|
+
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
|
|
2952
|
+
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
|
|
2953
|
+
return nullptr;
|
|
2954
|
+
}
|
|
2973
2955
|
}
|
|
2974
2956
|
}
|
|
2975
2957
|
|
|
@@ -3161,37 +3143,43 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
|
|
3161
3143
|
return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
|
|
3162
3144
|
}
|
|
3163
3145
|
|
|
3164
|
-
|
|
3165
|
-
|
|
3166
|
-
|
|
3167
|
-
|
|
3168
|
-
|
|
3169
|
-
|
|
3170
|
-
|
|
3171
|
-
|
|
3172
|
-
|
|
3146
|
+
struct ggml_cgraph * llama_graph_reserve(
|
|
3147
|
+
struct llama_context * ctx,
|
|
3148
|
+
uint32_t n_tokens,
|
|
3149
|
+
uint32_t n_seqs,
|
|
3150
|
+
uint32_t n_outputs) {
|
|
3151
|
+
auto * memory = ctx->get_memory();
|
|
3152
|
+
llama_memory_context_ptr mctx;
|
|
3153
|
+
if (memory) {
|
|
3154
|
+
mctx = memory->init_full();
|
|
3155
|
+
}
|
|
3156
|
+
return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get());
|
|
3173
3157
|
}
|
|
3174
3158
|
|
|
3175
|
-
|
|
3159
|
+
// llama adapter API
|
|
3160
|
+
|
|
3161
|
+
int32_t llama_set_adapters_lora(
|
|
3176
3162
|
llama_context * ctx,
|
|
3177
|
-
llama_adapter_lora
|
|
3178
|
-
|
|
3163
|
+
llama_adapter_lora ** adapters,
|
|
3164
|
+
size_t n_adapters,
|
|
3165
|
+
float * scales) {
|
|
3166
|
+
if (adapters == nullptr || scales == nullptr) {
|
|
3167
|
+
GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
|
|
3168
|
+
}
|
|
3179
3169
|
|
|
3180
|
-
|
|
3181
|
-
}
|
|
3170
|
+
ctx->set_adapters_lora(adapters, n_adapters, scales);
|
|
3182
3171
|
|
|
3183
|
-
|
|
3184
|
-
ctx->clear_adapter_lora();
|
|
3172
|
+
return 0;
|
|
3185
3173
|
}
|
|
3186
3174
|
|
|
3187
|
-
int32_t
|
|
3175
|
+
int32_t llama_set_adapter_cvec(
|
|
3188
3176
|
llama_context * ctx,
|
|
3189
|
-
|
|
3190
|
-
|
|
3191
|
-
|
|
3192
|
-
|
|
3193
|
-
|
|
3194
|
-
bool res = ctx->
|
|
3177
|
+
const float * data,
|
|
3178
|
+
size_t len,
|
|
3179
|
+
int32_t n_embd,
|
|
3180
|
+
int32_t il_start,
|
|
3181
|
+
int32_t il_end) {
|
|
3182
|
+
bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
|
|
3195
3183
|
|
|
3196
3184
|
return res ? 0 : -1;
|
|
3197
3185
|
}
|