whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <algorithm>
|
|
9
9
|
#include <cassert>
|
|
10
10
|
#include <cmath>
|
|
11
|
+
#include <cstring>
|
|
11
12
|
#include <limits>
|
|
12
13
|
#include <map>
|
|
13
14
|
#include <stdexcept>
|
|
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
|
|
|
37
38
|
|
|
38
39
|
const uint32_t n_layer_kv = hparams.n_layer_kv();
|
|
39
40
|
|
|
41
|
+
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
|
42
|
+
struct ggml_backend_buft_comparator {
|
|
43
|
+
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
|
44
|
+
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
|
|
45
|
+
}
|
|
46
|
+
};
|
|
47
|
+
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
|
48
|
+
|
|
40
49
|
// create a context for each buffer type
|
|
41
|
-
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
42
50
|
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
|
43
51
|
auto it = ctx_map.find(buft);
|
|
44
52
|
if (it == ctx_map.end()) {
|
|
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
|
|
|
53
61
|
return nullptr;
|
|
54
62
|
}
|
|
55
63
|
|
|
56
|
-
ctx_map
|
|
57
|
-
ctxs.emplace_back(ctx);
|
|
64
|
+
ctx_map.emplace(buft, ctx);
|
|
58
65
|
|
|
59
66
|
return ctx;
|
|
60
67
|
}
|
|
61
68
|
|
|
62
|
-
return it->second;
|
|
69
|
+
return it->second.get();
|
|
63
70
|
};
|
|
64
71
|
|
|
65
72
|
GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
|
|
@@ -90,6 +97,8 @@ llama_kv_cache::llama_kv_cache(
|
|
|
90
97
|
__func__, hparams.n_embd_v_gqa_max());
|
|
91
98
|
}
|
|
92
99
|
|
|
100
|
+
const bool is_mla = hparams.is_mla();
|
|
101
|
+
|
|
93
102
|
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
|
94
103
|
if (!hparams.has_kv(il)) {
|
|
95
104
|
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
|
@@ -123,21 +132,21 @@ llama_kv_cache::llama_kv_cache(
|
|
|
123
132
|
throw std::runtime_error("failed to create ggml context for kv cache");
|
|
124
133
|
}
|
|
125
134
|
|
|
126
|
-
|
|
127
|
-
|
|
135
|
+
const bool has_k = true;
|
|
136
|
+
const bool has_v = !is_mla;
|
|
128
137
|
|
|
129
|
-
k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
|
|
130
|
-
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
|
|
138
|
+
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
|
|
139
|
+
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
|
|
131
140
|
|
|
132
|
-
ggml_format_name(k, "cache_k_l%d", il);
|
|
133
|
-
ggml_format_name(v, "cache_v_l%d", il);
|
|
141
|
+
has_k && ggml_format_name(k, "cache_k_l%d", il);
|
|
142
|
+
has_v && ggml_format_name(v, "cache_v_l%d", il);
|
|
134
143
|
|
|
135
144
|
std::vector<ggml_tensor *> k_stream;
|
|
136
145
|
std::vector<ggml_tensor *> v_stream;
|
|
137
146
|
|
|
138
147
|
for (uint32_t s = 0; s < n_stream; ++s) {
|
|
139
|
-
k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
|
|
140
|
-
v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
|
|
148
|
+
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
|
|
149
|
+
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
|
|
141
150
|
}
|
|
142
151
|
|
|
143
152
|
map_layer_ids[il] = layers.size();
|
|
@@ -170,11 +179,16 @@ llama_kv_cache::llama_kv_cache(
|
|
|
170
179
|
}
|
|
171
180
|
|
|
172
181
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
173
|
-
for (auto
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
182
|
+
for (auto & [buft, ctx] : ctx_map) {
|
|
183
|
+
ggml_backend_buffer_t buf;
|
|
184
|
+
if (model.hparams.no_alloc) {
|
|
185
|
+
buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
|
|
186
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
187
|
+
t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
|
|
188
|
+
}
|
|
189
|
+
} else {
|
|
190
|
+
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
|
|
191
|
+
}
|
|
178
192
|
if (!buf) {
|
|
179
193
|
throw std::runtime_error("failed to allocate buffer for kv cache");
|
|
180
194
|
}
|
|
@@ -182,7 +196,7 @@ llama_kv_cache::llama_kv_cache(
|
|
|
182
196
|
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
|
183
197
|
|
|
184
198
|
ggml_backend_buffer_clear(buf, 0);
|
|
185
|
-
|
|
199
|
+
ctxs_bufs.emplace_back(std::move(ctx), buf);
|
|
186
200
|
}
|
|
187
201
|
|
|
188
202
|
{
|
|
@@ -206,7 +220,7 @@ void llama_kv_cache::clear(bool data) {
|
|
|
206
220
|
}
|
|
207
221
|
|
|
208
222
|
if (data) {
|
|
209
|
-
for (auto & buf :
|
|
223
|
+
for (auto & [_, buf] : ctxs_bufs) {
|
|
210
224
|
ggml_backend_buffer_clear(buf.get(), 0);
|
|
211
225
|
}
|
|
212
226
|
}
|
|
@@ -337,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
337
351
|
llama_pos pos = v_cells[s0].pos_get(i);
|
|
338
352
|
llama_pos shift = v_cells[s0].get_shift(i);
|
|
339
353
|
|
|
354
|
+
llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
|
|
355
|
+
|
|
340
356
|
if (shift != 0) {
|
|
341
357
|
pos -= shift;
|
|
342
358
|
assert(pos >= 0);
|
|
@@ -348,6 +364,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
348
364
|
if (shift != 0) {
|
|
349
365
|
v_cells[s1].pos_add(i, shift);
|
|
350
366
|
}
|
|
367
|
+
|
|
368
|
+
v_cells[s1].ext_set(i, ext);
|
|
351
369
|
}
|
|
352
370
|
}
|
|
353
371
|
|
|
@@ -382,6 +400,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
|
|
382
400
|
|
|
383
401
|
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
384
402
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
403
|
+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
|
385
404
|
|
|
386
405
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
387
406
|
auto & head = v_heads[seq_to_stream[seq_id]];
|
|
@@ -426,6 +445,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
|
|
426
445
|
|
|
427
446
|
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
428
447
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
448
|
+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
|
429
449
|
|
|
430
450
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
431
451
|
|
|
@@ -475,9 +495,18 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
475
495
|
|
|
476
496
|
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
|
477
497
|
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
|
478
|
-
for (const
|
|
479
|
-
|
|
498
|
+
for (const auto & [ctx, buf] : ctxs_bufs) {
|
|
499
|
+
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());
|
|
500
|
+
|
|
501
|
+
if (hparams.no_alloc) {
|
|
502
|
+
GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
|
|
503
|
+
ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
|
|
504
|
+
} else {
|
|
505
|
+
// GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
|
|
506
|
+
ret[buft] += ggml_backend_buffer_get_size(buf.get());
|
|
507
|
+
}
|
|
480
508
|
}
|
|
509
|
+
|
|
481
510
|
return ret;
|
|
482
511
|
}
|
|
483
512
|
|
|
@@ -554,7 +583,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
|
|
|
554
583
|
break;
|
|
555
584
|
}
|
|
556
585
|
|
|
557
|
-
//
|
|
586
|
+
// remember the position that we found
|
|
558
587
|
res.push_back(sinfo_new);
|
|
559
588
|
|
|
560
589
|
// store the old state of the cells in the recovery stack
|
|
@@ -623,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
|
|
|
623
652
|
const auto & layer = layers[il];
|
|
624
653
|
|
|
625
654
|
ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
|
|
626
|
-
|
|
655
|
+
|
|
656
|
+
if (layer.v_stream[ssrc]) {
|
|
657
|
+
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
|
|
658
|
+
}
|
|
627
659
|
}
|
|
628
660
|
}
|
|
629
661
|
}
|
|
@@ -828,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
|
|
828
860
|
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
|
829
861
|
|
|
830
862
|
// SWA mask
|
|
831
|
-
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
|
863
|
+
if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
|
832
864
|
can_use = true;
|
|
833
865
|
}
|
|
834
866
|
}
|
|
@@ -899,6 +931,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
|
|
899
931
|
|
|
900
932
|
cells.pos_set(idx, ubatch.pos[i]);
|
|
901
933
|
|
|
934
|
+
if (ubatch.is_pos_2d()) {
|
|
935
|
+
llama_kv_cell_ext ext {
|
|
936
|
+
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
|
|
937
|
+
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
|
|
938
|
+
};
|
|
939
|
+
cells.ext_set(idx, ext);
|
|
940
|
+
}
|
|
941
|
+
|
|
902
942
|
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
903
943
|
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
|
904
944
|
}
|
|
@@ -934,6 +974,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
|
|
934
974
|
}
|
|
935
975
|
|
|
936
976
|
bool llama_kv_cache::get_can_shift() const {
|
|
977
|
+
// Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot.
|
|
978
|
+
if (model.arch == LLM_ARCH_STEP35) {
|
|
979
|
+
return false;
|
|
980
|
+
}
|
|
981
|
+
if (hparams.n_pos_per_embd() > 1) {
|
|
982
|
+
return false;
|
|
983
|
+
}
|
|
937
984
|
return true;
|
|
938
985
|
}
|
|
939
986
|
|
|
@@ -960,10 +1007,14 @@ bool llama_kv_cache::get_has_shift() const {
|
|
|
960
1007
|
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
|
961
1008
|
uint32_t result = 0;
|
|
962
1009
|
|
|
1010
|
+
// pad the n_kv value so that the graph remains constant across batches and can be reused
|
|
1011
|
+
// note: this also helps some backends with performance (e.g. https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
|
|
1012
|
+
const uint32_t n_pad_cur = std::max(n_pad, 256u);
|
|
1013
|
+
|
|
963
1014
|
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
|
964
1015
|
const auto & cells = v_cells[sinfo.strm[s]];
|
|
965
1016
|
|
|
966
|
-
result = std::max(std::min(cells.size(), std::max(
|
|
1017
|
+
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
|
|
967
1018
|
}
|
|
968
1019
|
|
|
969
1020
|
return result;
|
|
@@ -982,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
|
|
|
982
1033
|
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
|
983
1034
|
|
|
984
1035
|
return ggml_view_4d(ctx, k,
|
|
985
|
-
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
|
|
986
|
-
ggml_row_size(k->type, hparams.n_embd_head_k),
|
|
1036
|
+
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
|
|
1037
|
+
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
|
|
987
1038
|
ggml_row_size(k->type, n_embd_k_gqa),
|
|
988
1039
|
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
|
|
989
1040
|
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
|
|
@@ -1005,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
|
|
1005
1056
|
if (!v_trans) {
|
|
1006
1057
|
// note: v->nb[1] <= v->nb[2]
|
|
1007
1058
|
return ggml_view_4d(ctx, v,
|
|
1008
|
-
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
|
|
1009
|
-
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
|
1059
|
+
hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
|
|
1060
|
+
ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
|
|
1010
1061
|
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
|
|
1011
1062
|
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
|
|
1012
1063
|
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
|
|
@@ -1014,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
|
|
1014
1065
|
|
|
1015
1066
|
// note: v->nb[1] > v->nb[2]
|
|
1016
1067
|
return ggml_view_4d(ctx, v,
|
|
1017
|
-
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
|
|
1018
|
-
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
|
|
1068
|
+
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
|
|
1069
|
+
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
|
|
1019
1070
|
ggml_row_size(v->type, kv_size), // v->nb[2]
|
|
1020
1071
|
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
|
|
1021
1072
|
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
|
|
@@ -1201,78 +1252,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
|
|
|
1201
1252
|
}
|
|
1202
1253
|
}
|
|
1203
1254
|
|
|
1204
|
-
|
|
1205
|
-
const
|
|
1255
|
+
struct args_set_input_kq_mask {
|
|
1256
|
+
const llama_hparams & hparams;
|
|
1257
|
+
const llama_ubatch * ubatch;
|
|
1206
1258
|
|
|
1207
|
-
|
|
1208
|
-
|
|
1259
|
+
const std::vector<llama_kv_cells> & v_cells;
|
|
1260
|
+
const std::vector<uint32_t> & seq_to_stream;
|
|
1209
1261
|
|
|
1210
|
-
|
|
1211
|
-
|
|
1262
|
+
uint32_t n_swa;
|
|
1263
|
+
llama_swa_type swa_type;
|
|
1212
1264
|
|
|
1213
|
-
|
|
1265
|
+
int64_t n_kv;
|
|
1266
|
+
int64_t n_stream;
|
|
1267
|
+
int64_t n_tps;
|
|
1268
|
+
};
|
|
1214
1269
|
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
|
1220
|
-
|
|
1221
|
-
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
|
1222
|
-
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
1223
|
-
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
|
1224
|
-
// Causal mask:
|
|
1225
|
-
// xxx-------
|
|
1226
|
-
// xxxx------
|
|
1227
|
-
// xxxxx-----
|
|
1228
|
-
// Non-causal mask:
|
|
1229
|
-
// xxxxx-----
|
|
1230
|
-
// xxxxx-----
|
|
1231
|
-
// xxxxx-----
|
|
1232
|
-
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
|
1233
|
-
// TODO: optimize this section
|
|
1234
|
-
for (uint32_t h = 0; h < 1; ++h) {
|
|
1235
|
-
for (uint32_t s = 0; s < n_stream; ++s) {
|
|
1236
|
-
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
|
1237
|
-
const uint32_t i = s*n_tps + ii;
|
|
1270
|
+
template<bool causal, bool swa, bool is_2d, bool alibi>
|
|
1271
|
+
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
1272
|
+
//const auto & hparams = args.hparams;
|
|
1273
|
+
const auto & ubatch = args.ubatch;
|
|
1238
1274
|
|
|
1239
|
-
|
|
1275
|
+
const auto & v_cells = args.v_cells;
|
|
1276
|
+
const auto & seq_to_stream = args.seq_to_stream;
|
|
1240
1277
|
|
|
1241
|
-
|
|
1278
|
+
const uint32_t n_swa = args.n_swa;
|
|
1279
|
+
const llama_swa_type swa_type = args.swa_type;
|
|
1242
1280
|
|
|
1243
|
-
|
|
1281
|
+
const int64_t n_kv = args.n_kv;
|
|
1282
|
+
const int64_t n_stream = args.n_stream;
|
|
1283
|
+
const int64_t n_tps = args.n_tps;
|
|
1244
1284
|
|
|
1245
|
-
|
|
1285
|
+
// the min position in the batch for each sequence
|
|
1286
|
+
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
|
1287
|
+
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
|
1246
1288
|
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1289
|
+
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
|
1290
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
1291
|
+
|
|
1292
|
+
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
|
|
1293
|
+
}
|
|
1294
|
+
|
|
1295
|
+
for (uint32_t s = 0; s < n_stream; ++s) {
|
|
1296
|
+
// bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
|
|
1297
|
+
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
|
|
1298
|
+
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
|
|
1299
|
+
|
|
1300
|
+
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
|
1301
|
+
const uint32_t i = s*n_tps + ii;
|
|
1251
1302
|
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1303
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
1304
|
+
|
|
1305
|
+
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
|
|
1306
|
+
|
|
1307
|
+
llama_pos p0 = -1;
|
|
1308
|
+
const llama_pos p1 = ubatch->pos[i];
|
|
1309
|
+
|
|
1310
|
+
// for M-RoPE
|
|
1311
|
+
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
|
1312
|
+
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
|
1313
|
+
|
|
1314
|
+
const uint64_t idst = n_kv*i;
|
|
1315
|
+
|
|
1316
|
+
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
|
|
1317
|
+
// the only cells that could change are the ones whose positions are close to the
|
|
1318
|
+
// ones in the batch (i.e. due to causal masking, SWA, etc.)
|
|
1319
|
+
// keep track of those cells and shortcut the loop to save time
|
|
1320
|
+
// note: this optimization is not compatible with Alibi position encoding
|
|
1321
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
|
|
1322
|
+
bool prev = false;
|
|
1323
|
+
|
|
1324
|
+
auto & idxs = seq_idxs[seq_id];
|
|
1325
|
+
|
|
1326
|
+
if (!alibi) {
|
|
1327
|
+
if (seq_srct.find(seq_id) != seq_srct.end()) {
|
|
1328
|
+
const uint32_t srct = seq_srct[seq_id];
|
|
1329
|
+
|
|
1330
|
+
const uint64_t idst_prev = n_kv*srct;
|
|
1331
|
+
|
|
1332
|
+
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
|
|
1333
|
+
|
|
1334
|
+
prev = true;
|
|
1335
|
+
} else {
|
|
1336
|
+
idxs.clear();
|
|
1337
|
+
idxs.reserve(ubatch->n_tokens + n_swa + 32);
|
|
1338
|
+
|
|
1339
|
+
seq_srct[seq_id] = i;
|
|
1340
|
+
}
|
|
1341
|
+
}
|
|
1342
|
+
|
|
1343
|
+
for (uint32_t jj = 0; jj < n_kv; ++jj) {
|
|
1344
|
+
uint32_t j = jj;
|
|
1345
|
+
|
|
1346
|
+
// we have an existing mask for this sequence -> update just seq_idxs
|
|
1347
|
+
if (!alibi) {
|
|
1348
|
+
if (prev) {
|
|
1349
|
+
if (jj >= idxs.size()) {
|
|
1350
|
+
break;
|
|
1351
|
+
}
|
|
1352
|
+
|
|
1353
|
+
j = idxs[jj];
|
|
1255
1354
|
}
|
|
1355
|
+
}
|
|
1256
1356
|
|
|
1257
|
-
|
|
1357
|
+
if (cells.is_empty(j)) {
|
|
1358
|
+
goto skip;
|
|
1359
|
+
}
|
|
1360
|
+
|
|
1361
|
+
// mask the token if not the same sequence
|
|
1362
|
+
if (!cells.seq_has(j, seq_id)) {
|
|
1363
|
+
goto skip;
|
|
1364
|
+
}
|
|
1258
1365
|
|
|
1366
|
+
p0 = cells.pos_get(j);
|
|
1367
|
+
|
|
1368
|
+
if (!alibi) {
|
|
1369
|
+
if (!prev) {
|
|
1370
|
+
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
|
|
1371
|
+
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
|
|
1372
|
+
idxs.push_back(j);
|
|
1373
|
+
}
|
|
1374
|
+
}
|
|
1375
|
+
}
|
|
1376
|
+
|
|
1377
|
+
if (causal) {
|
|
1259
1378
|
// mask future tokens
|
|
1260
|
-
if (
|
|
1261
|
-
|
|
1379
|
+
if (p0 > p1) {
|
|
1380
|
+
goto skip;
|
|
1381
|
+
}
|
|
1382
|
+
|
|
1383
|
+
// M-RoPE causal mask
|
|
1384
|
+
if (is_2d) {
|
|
1385
|
+
if (p0 == p1) {
|
|
1386
|
+
const auto & p0_ext = cells.ext_get(j);
|
|
1387
|
+
|
|
1388
|
+
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
|
1389
|
+
goto skip;
|
|
1390
|
+
}
|
|
1391
|
+
}
|
|
1262
1392
|
}
|
|
1393
|
+
}
|
|
1263
1394
|
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1395
|
+
// apply SWA if any
|
|
1396
|
+
if (swa) {
|
|
1397
|
+
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
1398
|
+
goto skip;
|
|
1267
1399
|
}
|
|
1400
|
+
}
|
|
1268
1401
|
|
|
1269
|
-
|
|
1402
|
+
if (alibi) {
|
|
1403
|
+
data[idst + j] = -std::abs(p0 - p1);
|
|
1404
|
+
} else {
|
|
1405
|
+
data[idst + j] = 0.0f;
|
|
1270
1406
|
}
|
|
1407
|
+
|
|
1408
|
+
continue;
|
|
1409
|
+
skip:
|
|
1410
|
+
data[idst + j] = -INFINITY;
|
|
1271
1411
|
}
|
|
1272
1412
|
}
|
|
1273
1413
|
}
|
|
1274
1414
|
}
|
|
1275
1415
|
|
|
1416
|
+
template<bool causal, bool swa, bool is_2d>
|
|
1417
|
+
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
1418
|
+
const bool alibi = args.hparams.use_alibi;
|
|
1419
|
+
if (alibi) {
|
|
1420
|
+
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
|
|
1421
|
+
} else {
|
|
1422
|
+
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
|
|
1423
|
+
}
|
|
1424
|
+
}
|
|
1425
|
+
|
|
1426
|
+
template<bool causal, bool swa>
|
|
1427
|
+
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
1428
|
+
const bool is_2d = args.ubatch->is_pos_2d();
|
|
1429
|
+
if (is_2d) {
|
|
1430
|
+
set_input_kq_mask_impl<causal, swa, true> (args, data);
|
|
1431
|
+
} else {
|
|
1432
|
+
set_input_kq_mask_impl<causal, swa, false>(args, data);
|
|
1433
|
+
}
|
|
1434
|
+
}
|
|
1435
|
+
|
|
1436
|
+
template<bool causal>
|
|
1437
|
+
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
1438
|
+
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
|
1439
|
+
if (swa) {
|
|
1440
|
+
set_input_kq_mask_impl<causal, true> (args, data);
|
|
1441
|
+
} else {
|
|
1442
|
+
set_input_kq_mask_impl<causal, false>(args, data);
|
|
1443
|
+
}
|
|
1444
|
+
}
|
|
1445
|
+
|
|
1446
|
+
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
|
1447
|
+
const uint32_t n_tokens = ubatch->n_tokens;
|
|
1448
|
+
|
|
1449
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
|
1450
|
+
float * data = (float *) dst->data;
|
|
1451
|
+
|
|
1452
|
+
const int64_t n_kv = dst->ne[0];
|
|
1453
|
+
const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
|
|
1454
|
+
|
|
1455
|
+
GGML_ASSERT(n_tokens%n_stream == 0);
|
|
1456
|
+
|
|
1457
|
+
// n_tps == n_tokens_per_stream
|
|
1458
|
+
const int64_t n_tps = n_tokens/n_stream;
|
|
1459
|
+
|
|
1460
|
+
//const int64_t t_start = ggml_time_us();
|
|
1461
|
+
|
|
1462
|
+
const args_set_input_kq_mask args = {
|
|
1463
|
+
/*.hparams =*/ hparams,
|
|
1464
|
+
/*.ubatch =*/ ubatch,
|
|
1465
|
+
/*.v_cells =*/ v_cells,
|
|
1466
|
+
/*.seq_to_stream =*/ seq_to_stream,
|
|
1467
|
+
/*.n_swa =*/ n_swa,
|
|
1468
|
+
/*.swa_type =*/ swa_type,
|
|
1469
|
+
/*.n_kv =*/ n_kv,
|
|
1470
|
+
/*.n_stream =*/ n_stream,
|
|
1471
|
+
/*.n_tps =*/ n_tps,
|
|
1472
|
+
};
|
|
1473
|
+
|
|
1474
|
+
if (causal_attn) {
|
|
1475
|
+
set_input_kq_mask_impl<true> (args, data);
|
|
1476
|
+
} else {
|
|
1477
|
+
set_input_kq_mask_impl<false>(args, data);
|
|
1478
|
+
}
|
|
1479
|
+
|
|
1480
|
+
//const int64_t t_end = ggml_time_us();
|
|
1481
|
+
|
|
1482
|
+
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
|
|
1483
|
+
}
|
|
1484
|
+
|
|
1276
1485
|
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
1277
1486
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
1278
1487
|
|
|
@@ -1301,7 +1510,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
|
|
|
1301
1510
|
size_t llama_kv_cache::total_size() const {
|
|
1302
1511
|
size_t size = 0;
|
|
1303
1512
|
|
|
1304
|
-
for (const auto & buf :
|
|
1513
|
+
for (const auto & [_, buf] : ctxs_bufs) {
|
|
1305
1514
|
size += ggml_backend_buffer_get_size(buf.get());
|
|
1306
1515
|
}
|
|
1307
1516
|
|
|
@@ -1322,7 +1531,7 @@ size_t llama_kv_cache::size_v_bytes() const {
|
|
|
1322
1531
|
size_t size_v_bytes = 0;
|
|
1323
1532
|
|
|
1324
1533
|
for (const auto & layer : layers) {
|
|
1325
|
-
size_v_bytes += ggml_nbytes(layer.v);
|
|
1534
|
+
size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
|
|
1326
1535
|
}
|
|
1327
1536
|
|
|
1328
1537
|
return size_v_bytes;
|
|
@@ -1335,15 +1544,17 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|
|
1335
1544
|
ggml_tensor * shift,
|
|
1336
1545
|
ggml_tensor * factors,
|
|
1337
1546
|
float freq_base,
|
|
1338
|
-
float freq_scale
|
|
1547
|
+
float freq_scale,
|
|
1548
|
+
uint32_t il) const {
|
|
1339
1549
|
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
|
1340
1550
|
|
|
1341
|
-
const auto & yarn_ext_factor
|
|
1342
|
-
const auto & yarn_beta_fast
|
|
1343
|
-
const auto & yarn_beta_slow
|
|
1551
|
+
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
|
1552
|
+
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
|
1553
|
+
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
|
1554
|
+
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
|
1344
1555
|
|
|
1345
|
-
const auto & n_rot = hparams.n_rot;
|
|
1346
|
-
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
|
|
1556
|
+
const auto & n_rot = hparams.n_rot(il);
|
|
1557
|
+
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
|
1347
1558
|
// @ngxson : this is a workaround
|
|
1348
1559
|
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
|
1349
1560
|
// a normal RoPE should work, we just need to use the correct ordering
|
|
@@ -1351,12 +1562,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|
|
1351
1562
|
? LLAMA_ROPE_TYPE_NEOX
|
|
1352
1563
|
: hparams.rope_type;
|
|
1353
1564
|
|
|
1354
|
-
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
|
1355
|
-
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
|
1356
|
-
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
|
|
1357
|
-
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
|
1358
|
-
: cparams.yarn_attn_factor;
|
|
1359
|
-
|
|
1360
1565
|
ggml_tensor * tmp;
|
|
1361
1566
|
|
|
1362
1567
|
if (ggml_is_quantized(cur->type)) {
|
|
@@ -1402,9 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|
|
1402
1607
|
auto * ctx = res->get_ctx();
|
|
1403
1608
|
auto * gf = res->get_gf();
|
|
1404
1609
|
|
|
1405
|
-
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
|
1406
|
-
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
|
1407
|
-
|
|
1408
1610
|
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
|
1409
1611
|
|
|
1410
1612
|
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
|
@@ -1418,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|
|
1418
1620
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
1419
1621
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
1420
1622
|
|
|
1623
|
+
const auto n_rot = hparams.n_rot(il);
|
|
1624
|
+
const auto n_embd_head_k = hparams.n_embd_head_k(il);
|
|
1625
|
+
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
|
1626
|
+
|
|
1421
1627
|
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
|
1422
1628
|
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
|
1423
1629
|
|
|
@@ -1425,12 +1631,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|
|
1425
1631
|
|
|
1426
1632
|
ggml_tensor * k =
|
|
1427
1633
|
ggml_view_3d(ctx, layer.k,
|
|
1428
|
-
|
|
1634
|
+
n_rot, n_head_kv, get_size()*n_stream,
|
|
1429
1635
|
ggml_row_size(layer.k->type, n_embd_head_k),
|
|
1430
1636
|
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
|
1431
|
-
|
|
1637
|
+
ggml_row_size(layer.k->type, n_embd_nope));
|
|
1432
1638
|
|
|
1433
|
-
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
|
1639
|
+
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
|
1434
1640
|
|
|
1435
1641
|
ggml_build_forward_expand(gf, cur);
|
|
1436
1642
|
}
|
|
@@ -1440,10 +1646,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|
|
1440
1646
|
return gf;
|
|
1441
1647
|
}
|
|
1442
1648
|
|
|
1443
|
-
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
|
1444
|
-
return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
|
|
1445
|
-
}
|
|
1446
|
-
|
|
1447
1649
|
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
|
1448
1650
|
GGML_UNUSED(flags);
|
|
1449
1651
|
|
|
@@ -1518,9 +1720,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
|
|
1518
1720
|
|
|
1519
1721
|
const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
|
|
1520
1722
|
|
|
1723
|
+
slot_info sinfo;
|
|
1724
|
+
|
|
1521
1725
|
bool res = true;
|
|
1522
|
-
res = res && state_read_meta(io, strm, cell_count, seq_id);
|
|
1523
|
-
res = res && state_read_data(io, strm, cell_count);
|
|
1726
|
+
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
|
|
1727
|
+
res = res && state_read_data(io, strm, cell_count, sinfo);
|
|
1524
1728
|
|
|
1525
1729
|
if (!res) {
|
|
1526
1730
|
if (seq_id == -1) {
|
|
@@ -1554,6 +1758,11 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
|
|
|
1554
1758
|
io.write(&pos, sizeof(pos));
|
|
1555
1759
|
io.write(&n_seq_id, sizeof(n_seq_id));
|
|
1556
1760
|
|
|
1761
|
+
if (hparams.n_pos_per_embd() > 1) {
|
|
1762
|
+
const llama_kv_cell_ext ext = cells.ext_get(i);
|
|
1763
|
+
io.write(&ext, sizeof(ext));
|
|
1764
|
+
}
|
|
1765
|
+
|
|
1557
1766
|
for (const auto & seq_id : seq_ids) {
|
|
1558
1767
|
io.write(&seq_id, sizeof(seq_id));
|
|
1559
1768
|
}
|
|
@@ -1570,8 +1779,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1570
1779
|
io.write(&v_trans, sizeof(v_trans));
|
|
1571
1780
|
io.write(&n_layer, sizeof(n_layer));
|
|
1572
1781
|
|
|
1573
|
-
std::vector<uint8_t> tmp_buf;
|
|
1574
|
-
|
|
1575
1782
|
// Iterate and write all the keys first, each row is a cell
|
|
1576
1783
|
// Get whole range at a time
|
|
1577
1784
|
for (const auto & layer : layers) {
|
|
@@ -1589,7 +1796,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1589
1796
|
const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
|
|
1590
1797
|
io.write(&k_size_row, sizeof(k_size_row));
|
|
1591
1798
|
|
|
1592
|
-
// Read each range of cells of k_size length
|
|
1799
|
+
// Read each range of cells of k_size length and write out
|
|
1593
1800
|
for (const auto & range : cr.data) {
|
|
1594
1801
|
const size_t range_size = range.second - range.first;
|
|
1595
1802
|
const size_t buf_size = range_size * k_size_row;
|
|
@@ -1604,6 +1811,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1604
1811
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
1605
1812
|
|
|
1606
1813
|
auto * v = layer.v_stream[cr.strm];
|
|
1814
|
+
if (!v) {
|
|
1815
|
+
continue;
|
|
1816
|
+
}
|
|
1607
1817
|
|
|
1608
1818
|
// Write value type
|
|
1609
1819
|
const int32_t v_type_i = (int32_t) v->type;
|
|
@@ -1613,7 +1823,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1613
1823
|
const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
|
|
1614
1824
|
io.write(&v_size_row, sizeof(v_size_row));
|
|
1615
1825
|
|
|
1616
|
-
// Read each range of cells of v_size length
|
|
1826
|
+
// Read each range of cells of v_size length and write out
|
|
1617
1827
|
for (const auto & range : cr.data) {
|
|
1618
1828
|
const size_t range_size = range.second - range.first;
|
|
1619
1829
|
const size_t buf_size = range_size * v_size_row;
|
|
@@ -1630,6 +1840,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1630
1840
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
1631
1841
|
|
|
1632
1842
|
auto * v = layer.v_stream[cr.strm];
|
|
1843
|
+
if (!v) {
|
|
1844
|
+
continue;
|
|
1845
|
+
}
|
|
1633
1846
|
|
|
1634
1847
|
// Write value type
|
|
1635
1848
|
const int32_t v_type_i = (int32_t) v->type;
|
|
@@ -1644,7 +1857,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1644
1857
|
|
|
1645
1858
|
// For each row, we get the element values of each cell
|
|
1646
1859
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
1647
|
-
// Read each range of cells of v_size_el length
|
|
1860
|
+
// Read each range of cells of v_size_el length and write out
|
|
1648
1861
|
for (const auto & range : cr.data) {
|
|
1649
1862
|
const size_t range_size = range.second - range.first;
|
|
1650
1863
|
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
|
@@ -1656,7 +1869,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1656
1869
|
}
|
|
1657
1870
|
}
|
|
1658
1871
|
|
|
1659
|
-
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
|
1872
|
+
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
|
|
1660
1873
|
auto & cells = v_cells[strm];
|
|
1661
1874
|
auto & head = v_heads[strm];
|
|
1662
1875
|
|
|
@@ -1682,6 +1895,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1682
1895
|
return false;
|
|
1683
1896
|
}
|
|
1684
1897
|
|
|
1898
|
+
if (hparams.n_pos_per_embd() > 1) {
|
|
1899
|
+
llama_kv_cell_ext ext;
|
|
1900
|
+
io.read_to(&ext, sizeof(ext));
|
|
1901
|
+
|
|
1902
|
+
ubatch.pos[i + ubatch.n_tokens] = ext.y;
|
|
1903
|
+
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
|
|
1904
|
+
}
|
|
1905
|
+
|
|
1685
1906
|
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
|
1686
1907
|
{
|
|
1687
1908
|
llama_seq_id seq_id;
|
|
@@ -1693,28 +1914,26 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1693
1914
|
ubatch.seq_id[i] = &dest_seq_id;
|
|
1694
1915
|
}
|
|
1695
1916
|
|
|
1696
|
-
|
|
1917
|
+
sinfo = find_slot(ubatch, false);
|
|
1697
1918
|
if (sinfo.empty()) {
|
|
1698
1919
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
1699
1920
|
return false;
|
|
1700
1921
|
}
|
|
1701
1922
|
|
|
1923
|
+
// TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
|
|
1924
|
+
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
|
1702
1925
|
apply_ubatch(sinfo, ubatch);
|
|
1703
1926
|
|
|
1704
|
-
|
|
1927
|
+
LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
|
|
1705
1928
|
|
|
1706
|
-
//
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
|
1715
|
-
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
|
1716
|
-
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
|
1717
|
-
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
|
1929
|
+
// DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
|
|
1930
|
+
GGML_ASSERT(sinfo.n_stream() == 1);
|
|
1931
|
+
GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
|
|
1932
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1933
|
+
const uint32_t idx = sinfo.idxs[0][i];
|
|
1934
|
+
GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
|
|
1935
|
+
GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
|
|
1936
|
+
}
|
|
1718
1937
|
} else {
|
|
1719
1938
|
// whole KV cache restore
|
|
1720
1939
|
|
|
@@ -1734,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1734
1953
|
|
|
1735
1954
|
cells.pos_set(i, pos);
|
|
1736
1955
|
|
|
1956
|
+
if (hparams.n_pos_per_embd() > 1) {
|
|
1957
|
+
llama_kv_cell_ext ext;
|
|
1958
|
+
io.read_to(&ext, sizeof(ext));
|
|
1959
|
+
cells.ext_set(i, ext);
|
|
1960
|
+
}
|
|
1961
|
+
|
|
1737
1962
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
|
1738
1963
|
llama_seq_id seq_id;
|
|
1739
1964
|
io.read_to(&seq_id, sizeof(seq_id));
|
|
@@ -1747,15 +1972,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1747
1972
|
}
|
|
1748
1973
|
}
|
|
1749
1974
|
|
|
1975
|
+
// Create contiguous slot_info for whole cache restore
|
|
1976
|
+
sinfo.s0 = strm;
|
|
1977
|
+
sinfo.s1 = strm;
|
|
1978
|
+
sinfo.resize(1);
|
|
1979
|
+
sinfo.strm[0] = strm;
|
|
1980
|
+
sinfo.idxs[0].resize(cell_count);
|
|
1981
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1982
|
+
sinfo.idxs[0][i] = i;
|
|
1983
|
+
}
|
|
1984
|
+
|
|
1750
1985
|
head = 0;
|
|
1751
1986
|
}
|
|
1752
1987
|
|
|
1753
1988
|
return true;
|
|
1754
1989
|
}
|
|
1755
1990
|
|
|
1756
|
-
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
|
1991
|
+
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
|
|
1757
1992
|
auto & cells = v_cells[strm];
|
|
1758
|
-
auto & head = v_heads[strm];
|
|
1759
1993
|
|
|
1760
1994
|
uint32_t v_trans;
|
|
1761
1995
|
uint32_t n_layer;
|
|
@@ -1805,8 +2039,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1805
2039
|
}
|
|
1806
2040
|
|
|
1807
2041
|
if (cell_count) {
|
|
1808
|
-
|
|
1809
|
-
|
|
2042
|
+
if (sinfo.is_contiguous()) {
|
|
2043
|
+
// Fast path: contiguous cells, single memcpy
|
|
2044
|
+
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
|
2045
|
+
} else {
|
|
2046
|
+
// Slow path: scatter to non-contiguous positions
|
|
2047
|
+
const void * src = io.read(cell_count * k_size_row);
|
|
2048
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
2049
|
+
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
|
2050
|
+
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
|
2051
|
+
}
|
|
2052
|
+
}
|
|
1810
2053
|
}
|
|
1811
2054
|
}
|
|
1812
2055
|
|
|
@@ -1817,6 +2060,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1817
2060
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
1818
2061
|
|
|
1819
2062
|
auto * v = layer.v_stream[strm];
|
|
2063
|
+
if (!v) {
|
|
2064
|
+
continue;
|
|
2065
|
+
}
|
|
1820
2066
|
|
|
1821
2067
|
// Read type of value
|
|
1822
2068
|
int32_t v_type_i_ref;
|
|
@@ -1837,8 +2083,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1837
2083
|
}
|
|
1838
2084
|
|
|
1839
2085
|
if (cell_count) {
|
|
1840
|
-
|
|
1841
|
-
|
|
2086
|
+
if (sinfo.is_contiguous()) {
|
|
2087
|
+
// Fast path: contiguous cells, single memcpy
|
|
2088
|
+
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
|
2089
|
+
} else {
|
|
2090
|
+
// Slow path: scatter to non-contiguous positions
|
|
2091
|
+
const void * src = io.read(cell_count * v_size_row);
|
|
2092
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
2093
|
+
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
|
2094
|
+
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
|
2095
|
+
}
|
|
2096
|
+
}
|
|
1842
2097
|
}
|
|
1843
2098
|
}
|
|
1844
2099
|
} else {
|
|
@@ -1849,6 +2104,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1849
2104
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
1850
2105
|
|
|
1851
2106
|
auto * v = layer.v_stream[strm];
|
|
2107
|
+
if (!v) {
|
|
2108
|
+
continue;
|
|
2109
|
+
}
|
|
1852
2110
|
|
|
1853
2111
|
// Read type of value
|
|
1854
2112
|
int32_t v_type_i_ref;
|
|
@@ -1877,10 +2135,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1877
2135
|
}
|
|
1878
2136
|
|
|
1879
2137
|
if (cell_count) {
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
const
|
|
1883
|
-
|
|
2138
|
+
if (sinfo.is_contiguous()) {
|
|
2139
|
+
// Fast path: contiguous cells
|
|
2140
|
+
const uint32_t h = sinfo.head();
|
|
2141
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
2142
|
+
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
|
2143
|
+
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
|
2144
|
+
}
|
|
2145
|
+
} else {
|
|
2146
|
+
// Slow path: scatter to non-contiguous positions
|
|
2147
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
2148
|
+
const void * src = io.read(cell_count * v_size_el);
|
|
2149
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
2150
|
+
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
|
2151
|
+
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
|
2152
|
+
}
|
|
2153
|
+
}
|
|
1884
2154
|
}
|
|
1885
2155
|
}
|
|
1886
2156
|
}
|
|
@@ -2013,8 +2283,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
|
2013
2283
|
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
2014
2284
|
kv->set_input_pos_bucket(dst, ubatch);
|
|
2015
2285
|
}
|
|
2016
|
-
|
|
2017
|
-
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
|
|
2018
|
-
// the FA kernels require padding to avoid extra runtime boundary checks
|
|
2019
|
-
return cparams.flash_attn ? 256u : 32u;
|
|
2020
|
-
}
|