whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <algorithm>
|
|
9
9
|
#include <cassert>
|
|
10
10
|
#include <cmath>
|
|
11
|
+
#include <cstring>
|
|
11
12
|
#include <limits>
|
|
12
13
|
#include <map>
|
|
13
14
|
#include <stdexcept>
|
|
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
|
|
|
37
38
|
|
|
38
39
|
const uint32_t n_layer_kv = hparams.n_layer_kv();
|
|
39
40
|
|
|
41
|
+
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
|
|
42
|
+
struct ggml_backend_buft_comparator {
|
|
43
|
+
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
|
|
44
|
+
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
|
|
45
|
+
}
|
|
46
|
+
};
|
|
47
|
+
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
|
|
48
|
+
|
|
40
49
|
// create a context for each buffer type
|
|
41
|
-
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
42
50
|
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
|
43
51
|
auto it = ctx_map.find(buft);
|
|
44
52
|
if (it == ctx_map.end()) {
|
|
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
|
|
|
53
61
|
return nullptr;
|
|
54
62
|
}
|
|
55
63
|
|
|
56
|
-
ctx_map
|
|
57
|
-
ctxs.emplace_back(ctx);
|
|
64
|
+
ctx_map.emplace(buft, ctx);
|
|
58
65
|
|
|
59
66
|
return ctx;
|
|
60
67
|
}
|
|
61
68
|
|
|
62
|
-
return it->second;
|
|
69
|
+
return it->second.get();
|
|
63
70
|
};
|
|
64
71
|
|
|
65
72
|
GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
|
|
@@ -123,11 +130,8 @@ llama_kv_cache::llama_kv_cache(
|
|
|
123
130
|
throw std::runtime_error("failed to create ggml context for kv cache");
|
|
124
131
|
}
|
|
125
132
|
|
|
126
|
-
ggml_tensor * k;
|
|
127
|
-
ggml_tensor * v;
|
|
128
|
-
|
|
129
|
-
k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
|
|
130
|
-
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
|
|
133
|
+
ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
|
|
134
|
+
ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
|
|
131
135
|
|
|
132
136
|
ggml_format_name(k, "cache_k_l%d", il);
|
|
133
137
|
ggml_format_name(v, "cache_v_l%d", il);
|
|
@@ -170,11 +174,16 @@ llama_kv_cache::llama_kv_cache(
|
|
|
170
174
|
}
|
|
171
175
|
|
|
172
176
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
173
|
-
for (auto
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
177
|
+
for (auto & [buft, ctx] : ctx_map) {
|
|
178
|
+
ggml_backend_buffer_t buf;
|
|
179
|
+
if (model.hparams.no_alloc) {
|
|
180
|
+
buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
|
|
181
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
|
|
182
|
+
t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
|
|
183
|
+
}
|
|
184
|
+
} else {
|
|
185
|
+
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
|
|
186
|
+
}
|
|
178
187
|
if (!buf) {
|
|
179
188
|
throw std::runtime_error("failed to allocate buffer for kv cache");
|
|
180
189
|
}
|
|
@@ -182,7 +191,7 @@ llama_kv_cache::llama_kv_cache(
|
|
|
182
191
|
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
|
183
192
|
|
|
184
193
|
ggml_backend_buffer_clear(buf, 0);
|
|
185
|
-
|
|
194
|
+
ctxs_bufs.emplace_back(std::move(ctx), buf);
|
|
186
195
|
}
|
|
187
196
|
|
|
188
197
|
{
|
|
@@ -206,7 +215,7 @@ void llama_kv_cache::clear(bool data) {
|
|
|
206
215
|
}
|
|
207
216
|
|
|
208
217
|
if (data) {
|
|
209
|
-
for (auto & buf :
|
|
218
|
+
for (auto & [_, buf] : ctxs_bufs) {
|
|
210
219
|
ggml_backend_buffer_clear(buf.get(), 0);
|
|
211
220
|
}
|
|
212
221
|
}
|
|
@@ -337,6 +346,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
337
346
|
llama_pos pos = v_cells[s0].pos_get(i);
|
|
338
347
|
llama_pos shift = v_cells[s0].get_shift(i);
|
|
339
348
|
|
|
349
|
+
llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
|
|
350
|
+
|
|
340
351
|
if (shift != 0) {
|
|
341
352
|
pos -= shift;
|
|
342
353
|
assert(pos >= 0);
|
|
@@ -348,6 +359,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
348
359
|
if (shift != 0) {
|
|
349
360
|
v_cells[s1].pos_add(i, shift);
|
|
350
361
|
}
|
|
362
|
+
|
|
363
|
+
v_cells[s1].ext_set(i, ext);
|
|
351
364
|
}
|
|
352
365
|
}
|
|
353
366
|
|
|
@@ -382,6 +395,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
|
|
382
395
|
|
|
383
396
|
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
384
397
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
398
|
+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
|
385
399
|
|
|
386
400
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
387
401
|
auto & head = v_heads[seq_to_stream[seq_id]];
|
|
@@ -426,6 +440,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
|
|
426
440
|
|
|
427
441
|
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
428
442
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
443
|
+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
|
429
444
|
|
|
430
445
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
431
446
|
|
|
@@ -475,9 +490,18 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
475
490
|
|
|
476
491
|
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
|
477
492
|
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
|
478
|
-
for (const
|
|
479
|
-
|
|
493
|
+
for (const auto & [ctx, buf] : ctxs_bufs) {
|
|
494
|
+
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());
|
|
495
|
+
|
|
496
|
+
if (hparams.no_alloc) {
|
|
497
|
+
GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
|
|
498
|
+
ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
|
|
499
|
+
} else {
|
|
500
|
+
// GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
|
|
501
|
+
ret[buft] += ggml_backend_buffer_get_size(buf.get());
|
|
502
|
+
}
|
|
480
503
|
}
|
|
504
|
+
|
|
481
505
|
return ret;
|
|
482
506
|
}
|
|
483
507
|
|
|
@@ -899,6 +923,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
|
|
899
923
|
|
|
900
924
|
cells.pos_set(idx, ubatch.pos[i]);
|
|
901
925
|
|
|
926
|
+
if (ubatch.is_pos_2d()) {
|
|
927
|
+
llama_kv_cell_ext ext {
|
|
928
|
+
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
|
|
929
|
+
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
|
|
930
|
+
};
|
|
931
|
+
cells.ext_set(idx, ext);
|
|
932
|
+
}
|
|
933
|
+
|
|
902
934
|
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
903
935
|
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
|
904
936
|
}
|
|
@@ -960,10 +992,14 @@ bool llama_kv_cache::get_has_shift() const {
|
|
|
960
992
|
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
|
961
993
|
uint32_t result = 0;
|
|
962
994
|
|
|
995
|
+
// pad the n_kv value so that the graph remains constant across batches and can be reused
|
|
996
|
+
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
|
|
997
|
+
const uint32_t n_pad_cur = std::max(n_pad, 256u);
|
|
998
|
+
|
|
963
999
|
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
|
964
1000
|
const auto & cells = v_cells[sinfo.strm[s]];
|
|
965
1001
|
|
|
966
|
-
result = std::max(std::min(cells.size(), std::max(
|
|
1002
|
+
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
|
|
967
1003
|
}
|
|
968
1004
|
|
|
969
1005
|
return result;
|
|
@@ -1213,8 +1249,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|
|
1213
1249
|
GGML_ASSERT(n_tokens%n_stream == 0);
|
|
1214
1250
|
|
|
1215
1251
|
// n_tps == n_tokens_per_stream
|
|
1216
|
-
const int64_t n_tps
|
|
1217
|
-
const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
|
|
1252
|
+
const int64_t n_tps = n_tokens/n_stream;
|
|
1218
1253
|
|
|
1219
1254
|
std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
|
1220
1255
|
|
|
@@ -1242,7 +1277,12 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|
|
1242
1277
|
|
|
1243
1278
|
const llama_pos p1 = ubatch->pos[i];
|
|
1244
1279
|
|
|
1245
|
-
|
|
1280
|
+
// for M-RoPE
|
|
1281
|
+
const bool is_2d = ubatch->is_pos_2d();
|
|
1282
|
+
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
|
1283
|
+
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
|
1284
|
+
|
|
1285
|
+
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
|
1246
1286
|
|
|
1247
1287
|
for (uint32_t j = 0; j < n_kv; ++j) {
|
|
1248
1288
|
if (cells.is_empty(j)) {
|
|
@@ -1261,6 +1301,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|
|
1261
1301
|
continue;
|
|
1262
1302
|
}
|
|
1263
1303
|
|
|
1304
|
+
// M-RoPE causal mask
|
|
1305
|
+
if (causal_attn && is_2d && p0 == p1) {
|
|
1306
|
+
const auto & p0_ext = cells.ext_get(j);
|
|
1307
|
+
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
|
1308
|
+
continue;
|
|
1309
|
+
}
|
|
1310
|
+
}
|
|
1311
|
+
|
|
1264
1312
|
// apply SWA if any
|
|
1265
1313
|
if (is_masked_swa(p0, p1)) {
|
|
1266
1314
|
continue;
|
|
@@ -1301,7 +1349,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
|
|
|
1301
1349
|
size_t llama_kv_cache::total_size() const {
|
|
1302
1350
|
size_t size = 0;
|
|
1303
1351
|
|
|
1304
|
-
for (const auto & buf :
|
|
1352
|
+
for (const auto & [_, buf] : ctxs_bufs) {
|
|
1305
1353
|
size += ggml_backend_buffer_get_size(buf.get());
|
|
1306
1354
|
}
|
|
1307
1355
|
|
|
@@ -1338,12 +1386,13 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|
|
1338
1386
|
float freq_scale) const {
|
|
1339
1387
|
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
|
1340
1388
|
|
|
1341
|
-
const auto & yarn_ext_factor
|
|
1342
|
-
const auto & yarn_beta_fast
|
|
1343
|
-
const auto & yarn_beta_slow
|
|
1389
|
+
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
|
1390
|
+
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
|
1391
|
+
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
|
1392
|
+
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
|
1344
1393
|
|
|
1345
1394
|
const auto & n_rot = hparams.n_rot;
|
|
1346
|
-
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
|
|
1395
|
+
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
|
1347
1396
|
// @ngxson : this is a workaround
|
|
1348
1397
|
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
|
1349
1398
|
// a normal RoPE should work, we just need to use the correct ordering
|
|
@@ -1351,12 +1400,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|
|
1351
1400
|
? LLAMA_ROPE_TYPE_NEOX
|
|
1352
1401
|
: hparams.rope_type;
|
|
1353
1402
|
|
|
1354
|
-
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
|
1355
|
-
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
|
1356
|
-
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
|
|
1357
|
-
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
|
1358
|
-
: cparams.yarn_attn_factor;
|
|
1359
|
-
|
|
1360
1403
|
ggml_tensor * tmp;
|
|
1361
1404
|
|
|
1362
1405
|
if (ggml_is_quantized(cur->type)) {
|
|
@@ -1518,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
|
|
1518
1561
|
|
|
1519
1562
|
const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
|
|
1520
1563
|
|
|
1564
|
+
slot_info sinfo;
|
|
1565
|
+
|
|
1521
1566
|
bool res = true;
|
|
1522
|
-
res = res && state_read_meta(io, strm, cell_count, seq_id);
|
|
1523
|
-
res = res && state_read_data(io, strm, cell_count);
|
|
1567
|
+
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
|
|
1568
|
+
res = res && state_read_data(io, strm, cell_count, sinfo);
|
|
1524
1569
|
|
|
1525
1570
|
if (!res) {
|
|
1526
1571
|
if (seq_id == -1) {
|
|
@@ -1554,6 +1599,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
|
|
|
1554
1599
|
io.write(&pos, sizeof(pos));
|
|
1555
1600
|
io.write(&n_seq_id, sizeof(n_seq_id));
|
|
1556
1601
|
|
|
1602
|
+
// TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
|
|
1603
|
+
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
|
1604
|
+
|
|
1557
1605
|
for (const auto & seq_id : seq_ids) {
|
|
1558
1606
|
io.write(&seq_id, sizeof(seq_id));
|
|
1559
1607
|
}
|
|
@@ -1656,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|
|
1656
1704
|
}
|
|
1657
1705
|
}
|
|
1658
1706
|
|
|
1659
|
-
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
|
1707
|
+
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
|
|
1660
1708
|
auto & cells = v_cells[strm];
|
|
1661
1709
|
auto & head = v_heads[strm];
|
|
1662
1710
|
|
|
@@ -1693,28 +1741,26 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1693
1741
|
ubatch.seq_id[i] = &dest_seq_id;
|
|
1694
1742
|
}
|
|
1695
1743
|
|
|
1696
|
-
|
|
1744
|
+
sinfo = find_slot(ubatch, false);
|
|
1697
1745
|
if (sinfo.empty()) {
|
|
1698
1746
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
1699
1747
|
return false;
|
|
1700
1748
|
}
|
|
1701
1749
|
|
|
1750
|
+
// TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
|
|
1751
|
+
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
|
1702
1752
|
apply_ubatch(sinfo, ubatch);
|
|
1703
1753
|
|
|
1704
|
-
|
|
1754
|
+
LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
|
|
1705
1755
|
|
|
1706
|
-
//
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
|
1715
|
-
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
|
1716
|
-
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
|
1717
|
-
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
|
1756
|
+
// DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
|
|
1757
|
+
GGML_ASSERT(sinfo.n_stream() == 1);
|
|
1758
|
+
GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
|
|
1759
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1760
|
+
const uint32_t idx = sinfo.idxs[0][i];
|
|
1761
|
+
GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
|
|
1762
|
+
GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
|
|
1763
|
+
}
|
|
1718
1764
|
} else {
|
|
1719
1765
|
// whole KV cache restore
|
|
1720
1766
|
|
|
@@ -1747,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1747
1793
|
}
|
|
1748
1794
|
}
|
|
1749
1795
|
|
|
1796
|
+
// Create contiguous slot_info for whole cache restore
|
|
1797
|
+
sinfo.s0 = strm;
|
|
1798
|
+
sinfo.s1 = strm;
|
|
1799
|
+
sinfo.resize(1);
|
|
1800
|
+
sinfo.strm[0] = strm;
|
|
1801
|
+
sinfo.idxs[0].resize(cell_count);
|
|
1802
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1803
|
+
sinfo.idxs[0][i] = i;
|
|
1804
|
+
}
|
|
1805
|
+
|
|
1750
1806
|
head = 0;
|
|
1751
1807
|
}
|
|
1752
1808
|
|
|
1753
1809
|
return true;
|
|
1754
1810
|
}
|
|
1755
1811
|
|
|
1756
|
-
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
|
1812
|
+
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
|
|
1757
1813
|
auto & cells = v_cells[strm];
|
|
1758
|
-
auto & head = v_heads[strm];
|
|
1759
1814
|
|
|
1760
1815
|
uint32_t v_trans;
|
|
1761
1816
|
uint32_t n_layer;
|
|
@@ -1805,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1805
1860
|
}
|
|
1806
1861
|
|
|
1807
1862
|
if (cell_count) {
|
|
1808
|
-
|
|
1809
|
-
|
|
1863
|
+
if (sinfo.is_contiguous()) {
|
|
1864
|
+
// Fast path: contiguous cells, single memcpy
|
|
1865
|
+
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
|
1866
|
+
} else {
|
|
1867
|
+
// Slow path: scatter to non-contiguous positions
|
|
1868
|
+
const void * src = io.read(cell_count * k_size_row);
|
|
1869
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1870
|
+
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
|
1871
|
+
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
|
1872
|
+
}
|
|
1873
|
+
}
|
|
1810
1874
|
}
|
|
1811
1875
|
}
|
|
1812
1876
|
|
|
@@ -1837,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1837
1901
|
}
|
|
1838
1902
|
|
|
1839
1903
|
if (cell_count) {
|
|
1840
|
-
|
|
1841
|
-
|
|
1904
|
+
if (sinfo.is_contiguous()) {
|
|
1905
|
+
// Fast path: contiguous cells, single memcpy
|
|
1906
|
+
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
|
1907
|
+
} else {
|
|
1908
|
+
// Slow path: scatter to non-contiguous positions
|
|
1909
|
+
const void * src = io.read(cell_count * v_size_row);
|
|
1910
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1911
|
+
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
|
1912
|
+
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
|
1913
|
+
}
|
|
1914
|
+
}
|
|
1842
1915
|
}
|
|
1843
1916
|
}
|
|
1844
1917
|
} else {
|
|
@@ -1877,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
1877
1950
|
}
|
|
1878
1951
|
|
|
1879
1952
|
if (cell_count) {
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
const
|
|
1883
|
-
|
|
1953
|
+
if (sinfo.is_contiguous()) {
|
|
1954
|
+
// Fast path: contiguous cells
|
|
1955
|
+
const uint32_t h = sinfo.head();
|
|
1956
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
1957
|
+
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
|
1958
|
+
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
|
1959
|
+
}
|
|
1960
|
+
} else {
|
|
1961
|
+
// Slow path: scatter to non-contiguous positions
|
|
1962
|
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
1963
|
+
const void * src = io.read(cell_count * v_size_el);
|
|
1964
|
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
1965
|
+
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
|
1966
|
+
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
|
1967
|
+
}
|
|
1968
|
+
}
|
|
1884
1969
|
}
|
|
1885
1970
|
}
|
|
1886
1971
|
}
|
|
@@ -2013,8 +2098,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
|
2013
2098
|
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
2014
2099
|
kv->set_input_pos_bucket(dst, ubatch);
|
|
2015
2100
|
}
|
|
2016
|
-
|
|
2017
|
-
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
|
|
2018
|
-
// the FA kernels require padding to avoid extra runtime boundary checks
|
|
2019
|
-
return cparams.flash_attn ? 256u : 32u;
|
|
2020
|
-
}
|
|
@@ -19,8 +19,6 @@ struct llama_context;
|
|
|
19
19
|
|
|
20
20
|
class llama_kv_cache : public llama_memory_i {
|
|
21
21
|
public:
|
|
22
|
-
static uint32_t get_padding(const llama_cparams & cparams);
|
|
23
|
-
|
|
24
22
|
struct stream_copy_info {
|
|
25
23
|
bool empty() const {
|
|
26
24
|
assert(ssrc.size() == sdst.size());
|
|
@@ -74,6 +72,23 @@ public:
|
|
|
74
72
|
void clear() {
|
|
75
73
|
idxs.clear();
|
|
76
74
|
}
|
|
75
|
+
|
|
76
|
+
// check if indices are contiguous starting from head()
|
|
77
|
+
bool is_contiguous() const {
|
|
78
|
+
if (idxs.empty() || idxs[0].empty()) {
|
|
79
|
+
return true;
|
|
80
|
+
}
|
|
81
|
+
if (idxs.size() > 1) {
|
|
82
|
+
return false;
|
|
83
|
+
}
|
|
84
|
+
const uint32_t h = idxs[0][0];
|
|
85
|
+
for (size_t i = 0; i < idxs[0].size(); ++i) {
|
|
86
|
+
if (idxs[0][i] != h + i) {
|
|
87
|
+
return false;
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
return true;
|
|
91
|
+
}
|
|
77
92
|
};
|
|
78
93
|
|
|
79
94
|
using slot_info_vec_t = std::vector<slot_info>;
|
|
@@ -217,8 +232,8 @@ private:
|
|
|
217
232
|
// this is the SWA type of the cache - not to be confused with the model SWA type
|
|
218
233
|
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
|
219
234
|
|
|
220
|
-
|
|
221
|
-
std::vector<ggml_backend_buffer_ptr
|
|
235
|
+
// ggml contexts for the KV cache along with the allocated backend buffers:
|
|
236
|
+
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
|
|
222
237
|
|
|
223
238
|
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
|
224
239
|
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
|
@@ -266,8 +281,8 @@ private:
|
|
|
266
281
|
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
|
267
282
|
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
|
268
283
|
|
|
269
|
-
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
|
270
|
-
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
|
284
|
+
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
|
|
285
|
+
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
|
|
271
286
|
};
|
|
272
287
|
|
|
273
288
|
class llama_kv_cache_context : public llama_memory_context_i {
|
|
@@ -290,7 +305,7 @@ public:
|
|
|
290
305
|
bool do_shift,
|
|
291
306
|
stream_copy_info sc_info);
|
|
292
307
|
|
|
293
|
-
// used to create a batch
|
|
308
|
+
// used to create a batch processing context from a batch
|
|
294
309
|
llama_kv_cache_context(
|
|
295
310
|
llama_kv_cache * kv,
|
|
296
311
|
slot_info_vec_t sinfos,
|
|
@@ -5,9 +5,27 @@
|
|
|
5
5
|
|
|
6
6
|
#include <bitset>
|
|
7
7
|
#include <cassert>
|
|
8
|
-
#include <
|
|
9
|
-
#include <set>
|
|
8
|
+
#include <cstring>
|
|
10
9
|
#include <map>
|
|
10
|
+
#include <set>
|
|
11
|
+
#include <vector>
|
|
12
|
+
|
|
13
|
+
struct llama_kv_cell_ext {
|
|
14
|
+
// 2D spatial positions, typically used for M-RoPE
|
|
15
|
+
llama_pos x = 0;
|
|
16
|
+
llama_pos y = 0;
|
|
17
|
+
|
|
18
|
+
// return true if the current 2D spatial position is greater than other
|
|
19
|
+
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
|
|
20
|
+
return (y > oy) || (y == oy && x > ox);
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
void reset() {
|
|
24
|
+
static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
|
|
25
|
+
|
|
26
|
+
memset(this, 0, sizeof(*this));
|
|
27
|
+
}
|
|
28
|
+
};
|
|
11
29
|
|
|
12
30
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
|
13
31
|
// TODO: add unit tests
|
|
@@ -16,6 +34,7 @@ public:
|
|
|
16
34
|
void reset() {
|
|
17
35
|
for (uint32_t i = 0; i < pos.size(); ++i) {
|
|
18
36
|
pos[i] = -1;
|
|
37
|
+
ext[i].reset();
|
|
19
38
|
shift[i] = 0;
|
|
20
39
|
seq[i].reset();
|
|
21
40
|
}
|
|
@@ -43,6 +62,7 @@ public:
|
|
|
43
62
|
|
|
44
63
|
void resize(uint32_t n) {
|
|
45
64
|
pos.resize(n);
|
|
65
|
+
ext.resize(n);
|
|
46
66
|
shift.resize(n);
|
|
47
67
|
seq.resize(n);
|
|
48
68
|
|
|
@@ -108,6 +128,7 @@ public:
|
|
|
108
128
|
const auto idx = i + j;
|
|
109
129
|
|
|
110
130
|
res.pos[j] = pos[idx];
|
|
131
|
+
res.ext[j] = ext[idx];
|
|
111
132
|
res.seq[j] = seq[idx];
|
|
112
133
|
|
|
113
134
|
assert(shift[idx] == 0);
|
|
@@ -126,6 +147,7 @@ public:
|
|
|
126
147
|
const auto idx = idxs[j];
|
|
127
148
|
|
|
128
149
|
res.pos[j] = pos[idx];
|
|
150
|
+
res.ext[j] = ext[idx];
|
|
129
151
|
res.seq[j] = seq[idx];
|
|
130
152
|
|
|
131
153
|
assert(shift[idx] == 0);
|
|
@@ -154,6 +176,7 @@ public:
|
|
|
154
176
|
}
|
|
155
177
|
|
|
156
178
|
pos[idx] = other.pos[j];
|
|
179
|
+
ext[idx] = other.ext[j];
|
|
157
180
|
seq[idx] = other.seq[j];
|
|
158
181
|
|
|
159
182
|
if (pos[idx] != -1) {
|
|
@@ -184,6 +207,7 @@ public:
|
|
|
184
207
|
}
|
|
185
208
|
|
|
186
209
|
pos[idx] = other.pos[j];
|
|
210
|
+
ext[idx] = other.ext[j];
|
|
187
211
|
seq[idx] = other.seq[j];
|
|
188
212
|
|
|
189
213
|
if (pos[idx] != -1) {
|
|
@@ -203,6 +227,7 @@ public:
|
|
|
203
227
|
seq[i].reset();
|
|
204
228
|
|
|
205
229
|
pos[i] = -1;
|
|
230
|
+
ext[i].reset();
|
|
206
231
|
shift[i] = 0;
|
|
207
232
|
|
|
208
233
|
used.erase(i);
|
|
@@ -221,6 +246,7 @@ public:
|
|
|
221
246
|
|
|
222
247
|
if (seq[i].none()) {
|
|
223
248
|
pos[i] = -1;
|
|
249
|
+
ext[i].reset();
|
|
224
250
|
shift[i] = 0;
|
|
225
251
|
|
|
226
252
|
used.erase(i);
|
|
@@ -250,6 +276,7 @@ public:
|
|
|
250
276
|
seq[i].reset();
|
|
251
277
|
|
|
252
278
|
pos[i] = -1;
|
|
279
|
+
ext[i].reset();
|
|
253
280
|
shift[i] = 0;
|
|
254
281
|
|
|
255
282
|
used.erase(i);
|
|
@@ -340,6 +367,13 @@ public:
|
|
|
340
367
|
return pos[i];
|
|
341
368
|
}
|
|
342
369
|
|
|
370
|
+
const llama_kv_cell_ext & ext_get(uint32_t i) const {
|
|
371
|
+
assert(i < pos.size());
|
|
372
|
+
assert(pos[i] != -1);
|
|
373
|
+
|
|
374
|
+
return ext[i];
|
|
375
|
+
}
|
|
376
|
+
|
|
343
377
|
// note: call only if the cell is not empty
|
|
344
378
|
llama_pos get_shift(uint32_t i) const {
|
|
345
379
|
assert(i < pos.size());
|
|
@@ -368,6 +402,11 @@ public:
|
|
|
368
402
|
used.insert(i);
|
|
369
403
|
}
|
|
370
404
|
|
|
405
|
+
void ext_set(uint32_t i, llama_kv_cell_ext p) {
|
|
406
|
+
assert(i < ext.size());
|
|
407
|
+
ext[i] = p;
|
|
408
|
+
}
|
|
409
|
+
|
|
371
410
|
// pos[i] = pos[i] + d
|
|
372
411
|
// sets "has_shift" to true
|
|
373
412
|
// note: call only if the cell is not empty
|
|
@@ -424,6 +463,9 @@ private:
|
|
|
424
463
|
|
|
425
464
|
std::vector<llama_pos> pos;
|
|
426
465
|
|
|
466
|
+
// stores extra info per cell
|
|
467
|
+
std::vector<llama_kv_cell_ext> ext;
|
|
468
|
+
|
|
427
469
|
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
|
428
470
|
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
|
429
471
|
//
|
|
@@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|
|
73
73
|
// if all tokens are output, split by sequence
|
|
74
74
|
ubatch = balloc.split_seq(n_ubatch);
|
|
75
75
|
} else {
|
|
76
|
-
|
|
76
|
+
// TODO: non-sequential equal split can be done if using unified KV cache
|
|
77
|
+
// for simplicity, we always use sequential equal split for now
|
|
78
|
+
ubatch = balloc.split_equal(n_ubatch, true);
|
|
77
79
|
}
|
|
78
80
|
|
|
79
81
|
if (ubatch.n_tokens == 0) {
|
|
@@ -175,17 +177,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
|
|
|
175
177
|
}
|
|
176
178
|
|
|
177
179
|
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
mem_recr->state_write(io, seq_id);
|
|
180
|
+
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
|
|
181
|
+
mem_attn->state_write(io, seq_id, flags);
|
|
182
|
+
}
|
|
183
|
+
mem_recr->state_write(io, seq_id, flags);
|
|
182
184
|
}
|
|
183
185
|
|
|
184
186
|
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
mem_recr->state_read(io, seq_id);
|
|
187
|
+
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
|
|
188
|
+
mem_attn->state_read(io, seq_id, flags);
|
|
189
|
+
}
|
|
190
|
+
mem_recr->state_read(io, seq_id, flags);
|
|
189
191
|
}
|
|
190
192
|
|
|
191
193
|
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
|
|
@@ -220,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
|
|
220
222
|
ubatches(std::move(ubatches)),
|
|
221
223
|
// note: here we copy the ubatches. not sure if this is ideal
|
|
222
224
|
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
|
223
|
-
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),
|
|
225
|
+
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
|
224
226
|
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
|
225
227
|
}
|
|
226
228
|
|