whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -7,11 +7,50 @@
|
|
|
7
7
|
#include "llama-kv-cache.h"
|
|
8
8
|
#include "llama-kv-cache-iswa.h"
|
|
9
9
|
#include "llama-memory-hybrid.h"
|
|
10
|
+
#include "llama-memory-hybrid-iswa.h"
|
|
10
11
|
#include "llama-memory-recurrent.h"
|
|
11
12
|
|
|
12
13
|
#include <cassert>
|
|
13
14
|
#include <cmath>
|
|
14
15
|
#include <cstring>
|
|
16
|
+
#include <numeric>
|
|
17
|
+
#include <sstream>
|
|
18
|
+
#include <unordered_set>
|
|
19
|
+
|
|
20
|
+
// dedup helpers
|
|
21
|
+
|
|
22
|
+
static ggml_tensor * build_kq_mask(
|
|
23
|
+
ggml_context * ctx,
|
|
24
|
+
const llama_kv_cache_context * mctx,
|
|
25
|
+
const llama_ubatch & ubatch,
|
|
26
|
+
const llama_cparams & cparams) {
|
|
27
|
+
const auto n_kv = mctx->get_n_kv();
|
|
28
|
+
const auto n_tokens = ubatch.n_tokens;
|
|
29
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
30
|
+
|
|
31
|
+
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
static bool can_reuse_kq_mask(
|
|
35
|
+
ggml_tensor * kq_mask,
|
|
36
|
+
const llama_kv_cache_context * mctx,
|
|
37
|
+
const llama_ubatch & ubatch,
|
|
38
|
+
const llama_cparams & cparams) {
|
|
39
|
+
const auto n_kv = mctx->get_n_kv();
|
|
40
|
+
const auto n_tokens = ubatch.n_tokens;
|
|
41
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
42
|
+
|
|
43
|
+
bool res = true;
|
|
44
|
+
|
|
45
|
+
res &= (kq_mask->ne[0] == n_kv);
|
|
46
|
+
res &= (kq_mask->ne[1] == n_tokens/n_stream);
|
|
47
|
+
res &= (kq_mask->ne[2] == 1);
|
|
48
|
+
res &= (kq_mask->ne[3] == n_stream);
|
|
49
|
+
|
|
50
|
+
return res;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// impl
|
|
15
54
|
|
|
16
55
|
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
17
56
|
if (ubatch->token) {
|
|
@@ -21,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
21
60
|
}
|
|
22
61
|
|
|
23
62
|
if (ubatch->embd) {
|
|
24
|
-
|
|
63
|
+
GGML_ASSERT(n_embd == embd->ne[0]);
|
|
64
|
+
|
|
25
65
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
26
66
|
|
|
27
67
|
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
|
@@ -31,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
31
71
|
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|
32
72
|
bool res = true;
|
|
33
73
|
|
|
34
|
-
res &= (!
|
|
35
|
-
res &= (!
|
|
74
|
+
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
|
75
|
+
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
|
36
76
|
|
|
37
77
|
return res;
|
|
38
78
|
}
|
|
@@ -62,7 +102,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
|
62
102
|
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
|
63
103
|
bool res = true;
|
|
64
104
|
|
|
65
|
-
res &= pos->ne[0] == params.ubatch.n_tokens;
|
|
105
|
+
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
|
|
66
106
|
|
|
67
107
|
return res;
|
|
68
108
|
}
|
|
@@ -71,11 +111,14 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
|
|
71
111
|
if (ubatch->pos && attn_scale) {
|
|
72
112
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
73
113
|
|
|
114
|
+
GGML_ASSERT(f_attn_temp_scale != 0.0f);
|
|
115
|
+
GGML_ASSERT(n_attn_temp_floor_scale != 0);
|
|
116
|
+
|
|
74
117
|
std::vector<float> attn_scale_data(n_tokens, 0.0f);
|
|
75
118
|
for (int i = 0; i < n_tokens; ++i) {
|
|
76
119
|
const float pos = ubatch->pos[i];
|
|
77
120
|
attn_scale_data[i] = std::log(
|
|
78
|
-
std::floor((pos +
|
|
121
|
+
std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
|
|
79
122
|
) * f_attn_temp_scale + 1.0;
|
|
80
123
|
}
|
|
81
124
|
|
|
@@ -92,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
92
135
|
|
|
93
136
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
94
137
|
|
|
95
|
-
for (int
|
|
96
|
-
for (int
|
|
97
|
-
|
|
98
|
-
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
99
|
-
}
|
|
138
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
139
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
140
|
+
data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
100
141
|
}
|
|
101
142
|
}
|
|
102
143
|
}
|
|
@@ -144,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
|
|
|
144
185
|
}
|
|
145
186
|
|
|
146
187
|
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
147
|
-
if (cparams.embeddings
|
|
188
|
+
if (cparams.embeddings &&
|
|
189
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
|
|
190
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
|
|
191
|
+
|
|
148
192
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
149
193
|
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
150
194
|
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
@@ -206,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
206
250
|
|
|
207
251
|
const bool last = (
|
|
208
252
|
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
|
209
|
-
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
|
253
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
|
210
254
|
);
|
|
211
255
|
|
|
212
256
|
for (int i = 0; i < n_tokens; ++i) {
|
|
@@ -251,6 +295,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
|
251
295
|
}
|
|
252
296
|
}
|
|
253
297
|
|
|
298
|
+
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
|
|
299
|
+
const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
|
|
300
|
+
|
|
301
|
+
this->mctx = mctx;
|
|
302
|
+
|
|
303
|
+
bool res = true;
|
|
304
|
+
|
|
305
|
+
res &= s_copy->ne[0] == mctx->get_n_rs();
|
|
306
|
+
|
|
307
|
+
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
308
|
+
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
|
|
309
|
+
|
|
310
|
+
res &= head == mctx->get_head();
|
|
311
|
+
res &= rs_z == mctx->get_rs_z();
|
|
312
|
+
|
|
313
|
+
return res;
|
|
314
|
+
}
|
|
315
|
+
|
|
254
316
|
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
255
317
|
GGML_UNUSED(ubatch);
|
|
256
318
|
|
|
@@ -261,12 +323,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
261
323
|
}
|
|
262
324
|
}
|
|
263
325
|
|
|
264
|
-
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
|
326
|
+
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
|
265
327
|
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
|
266
|
-
const char * swa_type_str =
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
328
|
+
const char * swa_type_str = "unknown";
|
|
329
|
+
|
|
330
|
+
switch (swa_type) {
|
|
331
|
+
case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
|
|
332
|
+
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
|
|
333
|
+
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
|
|
334
|
+
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
|
|
335
|
+
};
|
|
336
|
+
|
|
270
337
|
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
|
271
338
|
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
|
272
339
|
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
|
@@ -295,50 +362,65 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
295
362
|
const int64_t n_kv = ubatch->n_tokens;
|
|
296
363
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
297
364
|
|
|
298
|
-
|
|
299
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
|
300
|
-
|
|
301
|
-
float * data = (float *) kq_mask->data;
|
|
302
|
-
|
|
303
|
-
// [TAG_NO_CACHE_ISWA]
|
|
304
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
|
|
305
|
-
|
|
306
|
-
for (int h = 0; h < 1; ++h) {
|
|
365
|
+
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
|
307
366
|
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
308
367
|
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
368
|
+
const llama_pos p1 = ubatch->pos[i1];
|
|
309
369
|
|
|
310
|
-
|
|
311
|
-
float f = -INFINITY;
|
|
312
|
-
|
|
313
|
-
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
|
314
|
-
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
370
|
+
const uint64_t idst = i1*n_kv;
|
|
315
371
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
372
|
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
373
|
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
374
|
+
const llama_pos p0 = ubatch->pos[i0];
|
|
319
375
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
376
|
+
// mask different sequences
|
|
377
|
+
if (s0 != s1) {
|
|
378
|
+
continue;
|
|
379
|
+
}
|
|
323
380
|
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
381
|
+
// mask future tokens
|
|
382
|
+
if (cparams.causal_attn && p0 > p1) {
|
|
383
|
+
continue;
|
|
384
|
+
}
|
|
328
385
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
} else {
|
|
333
|
-
f = 0.0f;
|
|
334
|
-
}
|
|
386
|
+
// apply SWA if any
|
|
387
|
+
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
388
|
+
continue;
|
|
335
389
|
}
|
|
336
|
-
|
|
390
|
+
|
|
391
|
+
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
337
392
|
}
|
|
338
393
|
}
|
|
394
|
+
};
|
|
395
|
+
|
|
396
|
+
{
|
|
397
|
+
GGML_ASSERT(self_kq_mask);
|
|
398
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
|
399
|
+
|
|
400
|
+
float * data = (float *) self_kq_mask->data;
|
|
401
|
+
|
|
402
|
+
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
|
|
403
|
+
|
|
404
|
+
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
|
|
405
|
+
|
|
406
|
+
if (debug) {
|
|
407
|
+
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
|
|
408
|
+
}
|
|
339
409
|
}
|
|
340
|
-
|
|
341
|
-
|
|
410
|
+
|
|
411
|
+
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
|
412
|
+
GGML_ASSERT(self_kq_mask_swa);
|
|
413
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
|
414
|
+
|
|
415
|
+
float * data = (float *) self_kq_mask_swa->data;
|
|
416
|
+
|
|
417
|
+
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
|
|
418
|
+
|
|
419
|
+
fill_mask(data, hparams.n_swa, hparams.swa_type);
|
|
420
|
+
|
|
421
|
+
if (debug) {
|
|
422
|
+
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
|
423
|
+
}
|
|
342
424
|
}
|
|
343
425
|
}
|
|
344
426
|
|
|
@@ -359,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
|
|
359
441
|
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
360
442
|
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
361
443
|
|
|
362
|
-
res &= self_kq_mask
|
|
363
|
-
|
|
444
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
|
445
|
+
|
|
446
|
+
return res;
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
// Fills the inputs of a K-only attention input for the given ubatch:
// the K index tensor and the KQ mask (causal or not, per cparams.causal_attn).
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
    const auto causal_attn = cparams.causal_attn;

    mctx->set_input_k_idxs (self_k_idxs,  ubatch);
    mctx->set_input_kq_mask(self_kq_mask, ubatch, causal_attn);
}
|
|
454
|
+
|
|
455
|
+
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
|
|
456
|
+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
|
457
|
+
|
|
458
|
+
this->mctx = mctx;
|
|
459
|
+
|
|
460
|
+
bool res = true;
|
|
461
|
+
|
|
462
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
463
|
+
|
|
464
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
|
364
465
|
|
|
365
466
|
return res;
|
|
366
467
|
}
|
|
@@ -390,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
|
|
390
491
|
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
|
391
492
|
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
392
493
|
|
|
393
|
-
res &= self_kq_mask
|
|
394
|
-
res &=
|
|
395
|
-
|
|
396
|
-
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
|
397
|
-
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
494
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
|
495
|
+
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
|
398
496
|
|
|
399
497
|
return res;
|
|
400
498
|
}
|
|
@@ -410,34 +508,212 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
410
508
|
|
|
411
509
|
float * data = (float *) cross_kq_mask->data;
|
|
412
510
|
|
|
413
|
-
for (int
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
511
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
512
|
+
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
|
513
|
+
for (int j = 0; j < n_enc; ++j) {
|
|
514
|
+
float f = -INFINITY;
|
|
417
515
|
|
|
418
|
-
|
|
419
|
-
|
|
516
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
517
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
420
518
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
}
|
|
519
|
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
520
|
+
f = 0.0f;
|
|
424
521
|
}
|
|
425
|
-
|
|
426
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
|
427
522
|
}
|
|
428
|
-
}
|
|
429
523
|
|
|
430
|
-
|
|
431
|
-
for (int j = 0; j < n_enc; ++j) {
|
|
432
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
433
|
-
}
|
|
524
|
+
data[i*n_enc + j] = f;
|
|
434
525
|
}
|
|
435
526
|
}
|
|
436
527
|
}
|
|
437
528
|
|
|
438
529
|
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
439
|
-
inp_attn->
|
|
440
|
-
|
|
530
|
+
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
531
|
+
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
|
532
|
+
|
|
533
|
+
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
534
|
+
|
|
535
|
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
536
|
+
|
|
537
|
+
if (inp_rs->s_copy) {
|
|
538
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
539
|
+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
540
|
+
|
|
541
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
542
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
543
|
+
data[i] = mctx->get_recr()->s_copy(i);
|
|
544
|
+
}
|
|
545
|
+
}
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
|
549
|
+
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
|
550
|
+
|
|
551
|
+
this->mctx = mctx;
|
|
552
|
+
|
|
553
|
+
bool res = true;
|
|
554
|
+
|
|
555
|
+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
556
|
+
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
557
|
+
|
|
558
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
|
559
|
+
|
|
560
|
+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
561
|
+
|
|
562
|
+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
563
|
+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
564
|
+
|
|
565
|
+
res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
566
|
+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
567
|
+
|
|
568
|
+
return res;
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
// TODO: Hybrid input classes are a bit redundant.
|
|
572
|
+
// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
|
|
573
|
+
// Refactoring is required in the future.
|
|
574
|
+
void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
|
|
575
|
+
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
576
|
+
|
|
577
|
+
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
578
|
+
|
|
579
|
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
580
|
+
|
|
581
|
+
if (inp_rs->s_copy) {
|
|
582
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
583
|
+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
584
|
+
|
|
585
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
586
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
587
|
+
data[i] = mctx->get_recr()->s_copy(i);
|
|
588
|
+
}
|
|
589
|
+
}
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
|
|
593
|
+
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
|
594
|
+
|
|
595
|
+
this->mctx = mctx;
|
|
596
|
+
|
|
597
|
+
bool res = true;
|
|
598
|
+
|
|
599
|
+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
600
|
+
|
|
601
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
|
602
|
+
|
|
603
|
+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
604
|
+
|
|
605
|
+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
606
|
+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
607
|
+
|
|
608
|
+
res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
609
|
+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
610
|
+
|
|
611
|
+
return res;
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|
615
|
+
const auto * attn_ctx = mctx->get_attn();
|
|
616
|
+
|
|
617
|
+
// base tensors may not be allocated if there are no non-SWA attention layers
|
|
618
|
+
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
|
619
|
+
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
620
|
+
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
|
621
|
+
|
|
622
|
+
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
// swa tensors may not be allocated if there are no SWA attention layers
|
|
626
|
+
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
|
627
|
+
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
|
628
|
+
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
|
629
|
+
|
|
630
|
+
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
634
|
+
|
|
635
|
+
if (inp_rs->s_copy) {
|
|
636
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
637
|
+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
638
|
+
|
|
639
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
640
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
641
|
+
data[i] = mctx->get_recr()->s_copy(i);
|
|
642
|
+
}
|
|
643
|
+
}
|
|
644
|
+
}
|
|
645
|
+
|
|
646
|
+
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
|
|
647
|
+
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
|
|
648
|
+
|
|
649
|
+
this->mctx = mctx;
|
|
650
|
+
|
|
651
|
+
bool res = true;
|
|
652
|
+
|
|
653
|
+
const auto * attn_ctx = mctx->get_attn();
|
|
654
|
+
|
|
655
|
+
// base tensors may not be allocated if there are no non-SWA attention layers
|
|
656
|
+
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
|
657
|
+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
658
|
+
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
659
|
+
|
|
660
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
|
661
|
+
}
|
|
662
|
+
|
|
663
|
+
// swa tensors may not be allocated if there are no SWA attention layers
|
|
664
|
+
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
|
665
|
+
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
|
666
|
+
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
667
|
+
|
|
668
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
672
|
+
|
|
673
|
+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
674
|
+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
675
|
+
|
|
676
|
+
res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
677
|
+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
678
|
+
|
|
679
|
+
return res;
|
|
680
|
+
}
|
|
681
|
+
|
|
682
|
+
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
|
683
|
+
// set the inputs only for the active samplers in the current ubatch
|
|
684
|
+
std::unordered_set<llama_seq_id> active_samplers;
|
|
685
|
+
for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
|
|
686
|
+
if (ubatch->output[i]) {
|
|
687
|
+
llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
688
|
+
active_samplers.insert(seq_id);
|
|
689
|
+
}
|
|
690
|
+
}
|
|
691
|
+
|
|
692
|
+
for (auto seq_id : active_samplers) {
|
|
693
|
+
if (samplers.find(seq_id) == samplers.end()) {
|
|
694
|
+
continue;
|
|
695
|
+
}
|
|
696
|
+
|
|
697
|
+
auto & sampler = samplers[seq_id];
|
|
698
|
+
|
|
699
|
+
if (sampler->iface->backend_set_input) {
|
|
700
|
+
sampler->iface->backend_set_input(sampler);
|
|
701
|
+
}
|
|
702
|
+
}
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
|
706
|
+
if (samplers.size() != params.samplers.size()) {
|
|
707
|
+
return false;
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
for (const auto & [seq_id, sampler] : params.samplers) {
|
|
711
|
+
if (samplers[seq_id] != sampler) {
|
|
712
|
+
return false;
|
|
713
|
+
}
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
return true;
|
|
441
717
|
}
|
|
442
718
|
|
|
443
719
|
//
|
|
@@ -456,10 +732,15 @@ int64_t llm_graph_result::get_max_nodes() const {
|
|
|
456
732
|
}
|
|
457
733
|
|
|
458
734
|
void llm_graph_result::reset() {
|
|
459
|
-
|
|
735
|
+
t_inp_tokens = nullptr;
|
|
736
|
+
t_inp_embd = nullptr;
|
|
460
737
|
t_logits = nullptr;
|
|
461
738
|
t_embd = nullptr;
|
|
462
739
|
t_embd_pooled = nullptr;
|
|
740
|
+
t_sampled.clear();
|
|
741
|
+
t_sampled_probs.clear();
|
|
742
|
+
t_sampled_logits.clear();
|
|
743
|
+
t_candidates.clear();
|
|
463
744
|
|
|
464
745
|
params = {};
|
|
465
746
|
|
|
@@ -484,6 +765,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
|
|
484
765
|
}
|
|
485
766
|
}
|
|
486
767
|
|
|
768
|
+
void llm_graph_result::set_outputs() {
|
|
769
|
+
if (t_logits != nullptr) {
|
|
770
|
+
ggml_set_output(t_logits);
|
|
771
|
+
}
|
|
772
|
+
if (t_embd != nullptr) {
|
|
773
|
+
ggml_set_output(t_embd);
|
|
774
|
+
}
|
|
775
|
+
if (t_embd_pooled != nullptr) {
|
|
776
|
+
ggml_set_output(t_embd_pooled);
|
|
777
|
+
}
|
|
778
|
+
for (auto & [seq_id, t] : t_sampled) {
|
|
779
|
+
if (t != nullptr) {
|
|
780
|
+
ggml_set_output(t);
|
|
781
|
+
}
|
|
782
|
+
}
|
|
783
|
+
for (auto & [seq_id, t] : t_sampled_probs) {
|
|
784
|
+
if (t != nullptr) {
|
|
785
|
+
ggml_set_output(t);
|
|
786
|
+
}
|
|
787
|
+
}
|
|
788
|
+
for (auto & [seq_id, t] : t_sampled_logits) {
|
|
789
|
+
if (t != nullptr) {
|
|
790
|
+
ggml_set_output(t);
|
|
791
|
+
}
|
|
792
|
+
}
|
|
793
|
+
for (auto & [seq_id, t] : t_candidates) {
|
|
794
|
+
if (t != nullptr) {
|
|
795
|
+
ggml_set_output(t);
|
|
796
|
+
}
|
|
797
|
+
}
|
|
798
|
+
}
|
|
799
|
+
|
|
487
800
|
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
|
488
801
|
if (!this->params.allow_reuse(params)) {
|
|
489
802
|
if (debug > 1) {
|
|
@@ -536,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
536
849
|
ubatch (params.ubatch),
|
|
537
850
|
n_embd (hparams.n_embd),
|
|
538
851
|
n_layer (hparams.n_layer),
|
|
539
|
-
n_rot (hparams.n_rot),
|
|
852
|
+
n_rot (hparams.n_rot()),
|
|
540
853
|
n_ctx (cparams.n_ctx),
|
|
541
854
|
n_head (hparams.n_head()),
|
|
542
855
|
n_head_kv (hparams.n_head_kv()),
|
|
543
|
-
n_embd_head_k (hparams.n_embd_head_k),
|
|
856
|
+
n_embd_head_k (hparams.n_embd_head_k()),
|
|
544
857
|
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
|
545
|
-
n_embd_head_v (hparams.n_embd_head_v),
|
|
858
|
+
n_embd_head_v (hparams.n_embd_head_v()),
|
|
546
859
|
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
|
547
860
|
n_expert (hparams.n_expert),
|
|
548
861
|
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
|
@@ -565,6 +878,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
565
878
|
loras (params.loras),
|
|
566
879
|
mctx (params.mctx),
|
|
567
880
|
cross (params.cross),
|
|
881
|
+
samplers (params.samplers),
|
|
568
882
|
cb_func (params.cb),
|
|
569
883
|
res (params.res),
|
|
570
884
|
ctx0 (res->get_ctx()),
|
|
@@ -586,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec(
|
|
|
586
900
|
|
|
587
901
|
ggml_tensor * llm_graph_context::build_lora_mm(
|
|
588
902
|
ggml_tensor * w,
|
|
589
|
-
ggml_tensor * cur
|
|
903
|
+
ggml_tensor * cur,
|
|
904
|
+
ggml_tensor * w_s) const {
|
|
590
905
|
ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
|
591
906
|
|
|
592
907
|
for (const auto & lora : *loras) {
|
|
@@ -607,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm(
|
|
|
607
922
|
res = ggml_add(ctx0, res, ab_cur);
|
|
608
923
|
}
|
|
609
924
|
|
|
925
|
+
if (w_s) {
|
|
926
|
+
res = ggml_mul(ctx0, res, w_s);
|
|
927
|
+
}
|
|
928
|
+
|
|
610
929
|
return res;
|
|
611
930
|
}
|
|
612
931
|
|
|
@@ -732,6 +1051,26 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
732
1051
|
switch (type_op) {
|
|
733
1052
|
case LLM_FFN_SILU:
|
|
734
1053
|
if (gate && type_gate == LLM_FFN_PAR) {
|
|
1054
|
+
// Step35: HF clamps gate (after SiLU) and up before multiplication
|
|
1055
|
+
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
|
1056
|
+
const float limit = hparams.swiglu_clamp_shexp[il];
|
|
1057
|
+
constexpr float eps = 1e-6f;
|
|
1058
|
+
if (limit > eps) {
|
|
1059
|
+
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
|
1060
|
+
cb(gate_act, "ffn_silu", il);
|
|
1061
|
+
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
|
1062
|
+
cb(gate_act, "ffn_silu_clamped", il);
|
|
1063
|
+
|
|
1064
|
+
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
|
|
1065
|
+
cb(tmp, "ffn_up_clamped", il);
|
|
1066
|
+
|
|
1067
|
+
cur = ggml_mul(ctx0, gate_act, tmp);
|
|
1068
|
+
cb(cur, "ffn_swiglu_limited", il);
|
|
1069
|
+
type_gate = LLM_FFN_SEQ;
|
|
1070
|
+
break;
|
|
1071
|
+
}
|
|
1072
|
+
}
|
|
1073
|
+
|
|
735
1074
|
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
|
736
1075
|
cb(cur, "ffn_swiglu", il);
|
|
737
1076
|
type_gate = LLM_FFN_SEQ;
|
|
@@ -795,8 +1134,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
795
1134
|
|
|
796
1135
|
if (down) {
|
|
797
1136
|
cur = build_lora_mm(down, cur);
|
|
798
|
-
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
799
|
-
// GLM4 and
|
|
1137
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
|
1138
|
+
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
|
|
800
1139
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
801
1140
|
}
|
|
802
1141
|
}
|
|
@@ -828,11 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
828
1167
|
int64_t n_expert_used,
|
|
829
1168
|
llm_ffn_op_type type_op,
|
|
830
1169
|
bool norm_w,
|
|
831
|
-
bool scale_w,
|
|
832
1170
|
float w_scale,
|
|
833
1171
|
llama_expert_gating_func_type gating_op,
|
|
834
1172
|
int il,
|
|
835
|
-
ggml_tensor * probs_in
|
|
1173
|
+
ggml_tensor * probs_in,
|
|
1174
|
+
ggml_tensor * gate_up_exps,
|
|
1175
|
+
ggml_tensor * up_exps_s,
|
|
1176
|
+
ggml_tensor * gate_exps_s,
|
|
1177
|
+
ggml_tensor * down_exps_s) const {
|
|
836
1178
|
return build_moe_ffn(
|
|
837
1179
|
cur,
|
|
838
1180
|
gate_inp, /* gate_inp_b */ nullptr,
|
|
@@ -844,11 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
844
1186
|
n_expert_used,
|
|
845
1187
|
type_op,
|
|
846
1188
|
norm_w,
|
|
847
|
-
scale_w,
|
|
848
1189
|
w_scale,
|
|
849
1190
|
gating_op,
|
|
850
1191
|
il,
|
|
851
|
-
probs_in
|
|
1192
|
+
probs_in,
|
|
1193
|
+
gate_up_exps,
|
|
1194
|
+
/* gate_up_exps_b */ nullptr,
|
|
1195
|
+
up_exps_s,
|
|
1196
|
+
gate_exps_s,
|
|
1197
|
+
down_exps_s
|
|
852
1198
|
);
|
|
853
1199
|
}
|
|
854
1200
|
|
|
@@ -867,11 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
867
1213
|
int64_t n_expert_used,
|
|
868
1214
|
llm_ffn_op_type type_op,
|
|
869
1215
|
bool norm_w,
|
|
870
|
-
bool scale_w,
|
|
871
1216
|
float w_scale,
|
|
872
1217
|
llama_expert_gating_func_type gating_op,
|
|
873
1218
|
int il,
|
|
874
|
-
ggml_tensor * probs_in
|
|
1219
|
+
ggml_tensor * probs_in,
|
|
1220
|
+
ggml_tensor * gate_up_exps,
|
|
1221
|
+
ggml_tensor * gate_up_exps_b,
|
|
1222
|
+
ggml_tensor * up_exps_s,
|
|
1223
|
+
ggml_tensor * gate_exps_s,
|
|
1224
|
+
ggml_tensor * down_exps_s) const {
|
|
875
1225
|
const int64_t n_embd = cur->ne[0];
|
|
876
1226
|
const int64_t n_tokens = cur->ne[1];
|
|
877
1227
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
|
@@ -928,8 +1278,33 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
928
1278
|
cb(selection_probs, "ffn_moe_probs_biased", il);
|
|
929
1279
|
}
|
|
930
1280
|
|
|
1281
|
+
// select top n_group_used expert groups
|
|
1282
|
+
// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
|
|
1283
|
+
if (hparams.n_expert_groups > 1 && n_tokens > 0) {
|
|
1284
|
+
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
|
|
1285
|
+
|
|
1286
|
+
// organize experts into n_expert_groups
|
|
1287
|
+
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
|
|
1288
|
+
|
|
1289
|
+
ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
|
|
1290
|
+
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
|
|
1291
|
+
|
|
1292
|
+
// get top n_group_used expert groups
|
|
1293
|
+
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
|
|
1294
|
+
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
|
|
1295
|
+
|
|
1296
|
+
ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
|
|
1297
|
+
cb(expert_groups, "ffn_moe_group_topk", il);
|
|
1298
|
+
|
|
1299
|
+
// mask out the other groups
|
|
1300
|
+
selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
|
|
1301
|
+
selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
|
|
1302
|
+
selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
|
|
1303
|
+
cb(selection_probs, "ffn_moe_probs_masked", il);
|
|
1304
|
+
}
|
|
1305
|
+
|
|
931
1306
|
// select experts
|
|
932
|
-
ggml_tensor * selected_experts =
|
|
1307
|
+
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
|
933
1308
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
|
934
1309
|
cb(selected_experts, "ffn_moe_topk", il);
|
|
935
1310
|
|
|
@@ -959,12 +1334,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
959
1334
|
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
|
|
960
1335
|
cb(weights_sum, "ffn_moe_weights_sum", il);
|
|
961
1336
|
|
|
1337
|
+
// Avoid division by zero, clamp to smallest number representable by F16
|
|
1338
|
+
weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
|
|
1339
|
+
cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
|
|
1340
|
+
|
|
962
1341
|
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
|
|
963
1342
|
cb(weights, "ffn_moe_weights_norm", il);
|
|
964
1343
|
|
|
965
1344
|
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
|
966
1345
|
}
|
|
967
|
-
if (
|
|
1346
|
+
if (w_scale != 0.0f && w_scale != 1.0f) {
|
|
968
1347
|
weights = ggml_scale(ctx0, weights, w_scale);
|
|
969
1348
|
cb(weights, "ffn_moe_weights_scaled", il);
|
|
970
1349
|
}
|
|
@@ -981,30 +1360,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
981
1360
|
cb(cur, "ffn_moe_weighted", il);
|
|
982
1361
|
}
|
|
983
1362
|
|
|
984
|
-
ggml_tensor * up =
|
|
985
|
-
|
|
1363
|
+
ggml_tensor * up = nullptr;
|
|
1364
|
+
ggml_tensor * experts = nullptr;
|
|
986
1365
|
|
|
987
|
-
if (
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
1366
|
+
if (gate_up_exps) {
|
|
1367
|
+
// merged gate_up path: one mul_mat_id, then split into gate and up views
|
|
1368
|
+
ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
|
|
1369
|
+
cb(gate_up, "ffn_moe_gate_up", il);
|
|
991
1370
|
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
1371
|
+
if (gate_up_exps_b) {
|
|
1372
|
+
gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
|
|
1373
|
+
cb(gate_up, "ffn_moe_gate_up_biased", il);
|
|
1374
|
+
}
|
|
1375
|
+
|
|
1376
|
+
// apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
|
|
1377
|
+
if (up_exps_s) {
|
|
1378
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
|
|
1379
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1380
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1381
|
+
gate_up = ggml_mul(ctx0, gate_up, s);
|
|
1382
|
+
cb(gate_up, "ffn_moe_gate_up_scaled", il);
|
|
1383
|
+
}
|
|
1384
|
+
|
|
1385
|
+
const int64_t n_ff = gate_up->ne[0] / 2;
|
|
1386
|
+
cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
|
|
995
1387
|
cb(cur, "ffn_moe_gate", il);
|
|
1388
|
+
up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
|
|
1389
|
+
cb(up, "ffn_moe_up", il);
|
|
996
1390
|
} else {
|
|
997
|
-
|
|
998
|
-
|
|
1391
|
+
// separate gate and up path
|
|
1392
|
+
up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
1393
|
+
cb(up, "ffn_moe_up", il);
|
|
1394
|
+
|
|
1395
|
+
if (up_exps_b) {
|
|
1396
|
+
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
|
1397
|
+
cb(up, "ffn_moe_up_biased", il);
|
|
1398
|
+
}
|
|
1399
|
+
|
|
1400
|
+
// apply per-expert scale2 to up
|
|
1401
|
+
if (up_exps_s) {
|
|
1402
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
|
|
1403
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1404
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1405
|
+
up = ggml_mul(ctx0, up, s);
|
|
1406
|
+
cb(up, "ffn_moe_up_scaled", il);
|
|
1407
|
+
}
|
|
1408
|
+
|
|
1409
|
+
if (gate_exps) {
|
|
1410
|
+
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
1411
|
+
cb(cur, "ffn_moe_gate", il);
|
|
1412
|
+
} else {
|
|
1413
|
+
cur = up;
|
|
1414
|
+
}
|
|
1415
|
+
|
|
1416
|
+
if (gate_exps_b) {
|
|
1417
|
+
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
|
1418
|
+
cb(cur, "ffn_moe_gate_biased", il);
|
|
1419
|
+
}
|
|
999
1420
|
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1421
|
+
// apply per-expert scale2 to gate
|
|
1422
|
+
if (gate_exps_s) {
|
|
1423
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
|
|
1424
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1425
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1426
|
+
cur = ggml_mul(ctx0, cur, s);
|
|
1427
|
+
cb(cur, "ffn_moe_gate_scaled", il);
|
|
1428
|
+
}
|
|
1003
1429
|
}
|
|
1004
1430
|
|
|
1431
|
+
const bool has_gate = gate_exps || gate_up_exps;
|
|
1432
|
+
|
|
1005
1433
|
switch (type_op) {
|
|
1006
1434
|
case LLM_FFN_SILU:
|
|
1007
1435
|
if (gate_exps) {
|
|
1436
|
+
// Step35: per-layer clamp for routed experts
|
|
1437
|
+
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
|
1438
|
+
const float limit = hparams.swiglu_clamp_exp[il];
|
|
1439
|
+
constexpr float eps = 1e-6f;
|
|
1440
|
+
if (limit > eps) {
|
|
1441
|
+
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
|
1442
|
+
cb(gate_act, "ffn_moe_silu", il);
|
|
1443
|
+
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
|
1444
|
+
cb(gate_act, "ffn_moe_silu_clamped", il);
|
|
1445
|
+
|
|
1446
|
+
up = ggml_clamp(ctx0, up, -limit, limit);
|
|
1447
|
+
cb(up, "ffn_moe_up_clamped", il);
|
|
1448
|
+
|
|
1449
|
+
cur = ggml_mul(ctx0, gate_act, up);
|
|
1450
|
+
cb(cur, "ffn_moe_swiglu_limited", il);
|
|
1451
|
+
break;
|
|
1452
|
+
}
|
|
1453
|
+
}
|
|
1454
|
+
}
|
|
1455
|
+
|
|
1456
|
+
if (has_gate) {
|
|
1008
1457
|
cur = ggml_swiglu_split(ctx0, cur, up);
|
|
1009
1458
|
cb(cur, "ffn_moe_swiglu", il);
|
|
1010
1459
|
} else {
|
|
@@ -1012,7 +1461,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1012
1461
|
cb(cur, "ffn_moe_silu", il);
|
|
1013
1462
|
} break;
|
|
1014
1463
|
case LLM_FFN_GELU:
|
|
1015
|
-
if (
|
|
1464
|
+
if (has_gate) {
|
|
1016
1465
|
cur = ggml_geglu_split(ctx0, cur, up);
|
|
1017
1466
|
cb(cur, "ffn_moe_geglu", il);
|
|
1018
1467
|
} else {
|
|
@@ -1028,13 +1477,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1028
1477
|
cb(cur, "ffn_moe_swiglu_oai", il);
|
|
1029
1478
|
} break;
|
|
1030
1479
|
case LLM_FFN_RELU:
|
|
1031
|
-
if (
|
|
1480
|
+
if (has_gate) {
|
|
1032
1481
|
cur = ggml_reglu_split(ctx0, cur, up);
|
|
1033
1482
|
cb(cur, "ffn_moe_reglu", il);
|
|
1034
1483
|
} else {
|
|
1035
1484
|
cur = ggml_relu(ctx0, cur);
|
|
1036
1485
|
cb(cur, "ffn_moe_relu", il);
|
|
1037
1486
|
} break;
|
|
1487
|
+
case LLM_FFN_RELU_SQR:
|
|
1488
|
+
if (has_gate) {
|
|
1489
|
+
// TODO: add support for gated squared relu
|
|
1490
|
+
GGML_ABORT("fatal error: gated squared relu not implemented");
|
|
1491
|
+
} else {
|
|
1492
|
+
cur = ggml_relu(ctx0, cur);
|
|
1493
|
+
cur = ggml_sqr(ctx0, cur);
|
|
1494
|
+
cb(cur, "ffn_moe_relu_sqr", il);
|
|
1495
|
+
} break;
|
|
1038
1496
|
default:
|
|
1039
1497
|
GGML_ABORT("fatal error");
|
|
1040
1498
|
}
|
|
@@ -1047,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1047
1505
|
cb(experts, "ffn_moe_down_biased", il);
|
|
1048
1506
|
}
|
|
1049
1507
|
|
|
1508
|
+
// apply per-expert scale2 to down
|
|
1509
|
+
if (down_exps_s) {
|
|
1510
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
|
|
1511
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1512
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1513
|
+
experts = ggml_mul(ctx0, experts, s);
|
|
1514
|
+
cb(experts, "ffn_moe_down_scaled", il);
|
|
1515
|
+
}
|
|
1516
|
+
|
|
1050
1517
|
if (!weight_before_ffn) {
|
|
1051
1518
|
experts = ggml_mul(ctx0, experts, weights);
|
|
1052
1519
|
cb(cur, "ffn_moe_weighted", il);
|
|
@@ -1085,17 +1552,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1085
1552
|
|
|
1086
1553
|
// input embeddings with optional lora
|
|
1087
1554
|
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
1088
|
-
const int64_t
|
|
1555
|
+
const int64_t n_embd_inp = hparams.n_embd_inp();
|
|
1556
|
+
const int64_t n_embd = hparams.n_embd;
|
|
1557
|
+
|
|
1558
|
+
assert(n_embd_inp >= n_embd);
|
|
1559
|
+
|
|
1560
|
+
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
|
|
1089
1561
|
|
|
1090
|
-
|
|
1562
|
+
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
|
1563
|
+
cb(inp->tokens, "inp_tokens", -1);
|
|
1564
|
+
ggml_set_input(inp->tokens);
|
|
1565
|
+
res->t_inp_tokens = inp->tokens;
|
|
1091
1566
|
|
|
1092
|
-
|
|
1567
|
+
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
|
|
1568
|
+
cb(inp->embd, "inp_embd", -1);
|
|
1569
|
+
ggml_set_input(inp->embd);
|
|
1093
1570
|
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1571
|
+
// select one of the 2 inputs, based on the batch contents
|
|
1572
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
|
1573
|
+
std::array<ggml_tensor *, 2> inps;
|
|
1574
|
+
|
|
1575
|
+
// token embeddings path (ubatch.token != nullptr)
|
|
1576
|
+
{
|
|
1577
|
+
auto & cur = inps[0];
|
|
1099
1578
|
|
|
1100
1579
|
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
|
1101
1580
|
|
|
@@ -1116,22 +1595,43 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
|
1116
1595
|
|
|
1117
1596
|
cur = ggml_add(ctx0, cur, inpL_delta);
|
|
1118
1597
|
}
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1598
|
+
|
|
1599
|
+
if (n_embd_inp != n_embd) {
|
|
1600
|
+
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
|
|
1601
|
+
}
|
|
1602
|
+
}
|
|
1603
|
+
|
|
1604
|
+
// vector embeddings path (ubatch.embd != nullptr)
|
|
1605
|
+
{
|
|
1606
|
+
auto & cur = inps[1];
|
|
1122
1607
|
|
|
1123
1608
|
cur = inp->embd;
|
|
1124
1609
|
}
|
|
1125
1610
|
|
|
1611
|
+
assert(ggml_are_same_shape (inps[0], inps[1]));
|
|
1612
|
+
assert(ggml_are_same_stride(inps[0], inps[1]));
|
|
1613
|
+
|
|
1614
|
+
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
|
|
1615
|
+
|
|
1616
|
+
if (n_embd_inp != n_embd) {
|
|
1617
|
+
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
|
|
1618
|
+
}
|
|
1619
|
+
|
|
1620
|
+
res->t_inp_embd = cur;
|
|
1621
|
+
|
|
1126
1622
|
// For Granite architecture
|
|
1127
1623
|
if (hparams.f_embedding_scale != 0.0f) {
|
|
1128
1624
|
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
|
1129
1625
|
}
|
|
1130
1626
|
|
|
1131
|
-
cb(cur, "
|
|
1627
|
+
cb(cur, "embd", -1);
|
|
1132
1628
|
|
|
1133
1629
|
res->add_input(std::move(inp));
|
|
1134
1630
|
|
|
1631
|
+
// make sure the produced embeddings are immediately materialized in the ggml graph
|
|
1632
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
|
|
1633
|
+
ggml_build_forward_expand(gf, cur);
|
|
1634
|
+
|
|
1135
1635
|
return cur;
|
|
1136
1636
|
}
|
|
1137
1637
|
|
|
@@ -1149,13 +1649,14 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
|
|
|
1149
1649
|
}
|
|
1150
1650
|
|
|
1151
1651
|
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
1152
|
-
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
|
|
1652
|
+
auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
|
|
1153
1653
|
|
|
1154
1654
|
auto & cur = inp->attn_scale;
|
|
1155
1655
|
|
|
1156
1656
|
// this needs to be 1x1xN for broadcasting
|
|
1157
1657
|
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
|
|
1158
1658
|
ggml_set_input(cur);
|
|
1659
|
+
ggml_set_name(cur, "attn_scale");
|
|
1159
1660
|
|
|
1160
1661
|
res->add_input(std::move(inp));
|
|
1161
1662
|
|
|
@@ -1165,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
|
1165
1666
|
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
|
1166
1667
|
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
|
1167
1668
|
// but this would make the graph topology depend on the number of output tokens, which can interfere with
|
|
1168
|
-
// features that require constant topology such as
|
|
1669
|
+
// features that require constant topology such as pipeline parallelism
|
|
1169
1670
|
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
|
1170
1671
|
//if (n_outputs < n_tokens) {
|
|
1171
1672
|
// return nullptr;
|
|
@@ -1222,8 +1723,8 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
|
|
1222
1723
|
// return cur;
|
|
1223
1724
|
//}
|
|
1224
1725
|
|
|
1225
|
-
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.
|
|
1226
|
-
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc
|
|
1726
|
+
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
|
|
1727
|
+
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
|
1227
1728
|
|
|
1228
1729
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
|
1229
1730
|
ggml_set_input(cur);
|
|
@@ -1299,12 +1800,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1299
1800
|
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
|
1300
1801
|
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
|
1301
1802
|
|
|
1302
|
-
const auto n_kv = k->ne[1];
|
|
1303
|
-
|
|
1304
1803
|
ggml_tensor * cur;
|
|
1305
1804
|
|
|
1306
|
-
|
|
1307
|
-
if (
|
|
1805
|
+
const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
|
|
1806
|
+
if (use_flash_attn) {
|
|
1308
1807
|
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
|
|
1309
1808
|
|
|
1310
1809
|
if (v_trans) {
|
|
@@ -1330,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1330
1829
|
if (v_mla) {
|
|
1331
1830
|
#if 0
|
|
1332
1831
|
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
|
|
1333
|
-
// However, the code is optimized for dimensions 0 and 1 being large, so this is
|
|
1832
|
+
// However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
|
|
1334
1833
|
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
|
1335
1834
|
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
|
1336
1835
|
#else
|
|
@@ -1419,10 +1918,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
|
1419
1918
|
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
|
1420
1919
|
|
|
1421
1920
|
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
|
1422
|
-
inp->
|
|
1423
|
-
ggml_set_input(inp->
|
|
1921
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
|
1922
|
+
ggml_set_input(inp->self_kq_mask);
|
|
1424
1923
|
|
|
1425
|
-
inp->
|
|
1924
|
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
1925
|
+
|
|
1926
|
+
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
|
1927
|
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
|
1928
|
+
ggml_set_input(inp->self_kq_mask_swa);
|
|
1929
|
+
|
|
1930
|
+
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
1931
|
+
} else {
|
|
1932
|
+
inp->self_kq_mask_swa = nullptr;
|
|
1933
|
+
inp->self_kq_mask_swa_cnv = nullptr;
|
|
1934
|
+
}
|
|
1426
1935
|
|
|
1427
1936
|
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
|
|
1428
1937
|
}
|
|
@@ -1447,7 +1956,9 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1447
1956
|
ggml_build_forward_expand(gf, k_cur);
|
|
1448
1957
|
ggml_build_forward_expand(gf, v_cur);
|
|
1449
1958
|
|
|
1450
|
-
const
|
|
1959
|
+
const bool is_swa = hparams.is_swa(il);
|
|
1960
|
+
|
|
1961
|
+
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
|
1451
1962
|
|
|
1452
1963
|
// [TAG_NO_CACHE_PAD]
|
|
1453
1964
|
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
|
@@ -1488,14 +1999,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|
|
1488
1999
|
{
|
|
1489
2000
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
|
1490
2001
|
|
|
1491
|
-
const auto n_kv = mctx_cur->get_n_kv();
|
|
1492
|
-
const auto n_tokens = ubatch.n_tokens;
|
|
1493
|
-
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1494
|
-
|
|
1495
2002
|
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
1496
2003
|
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
|
1497
2004
|
|
|
1498
|
-
inp->self_kq_mask =
|
|
2005
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
|
2006
|
+
|
|
1499
2007
|
ggml_set_input(inp->self_kq_mask);
|
|
1500
2008
|
|
|
1501
2009
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1521,14 +2029,17 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1521
2029
|
ggml_tensor * v_cur,
|
|
1522
2030
|
ggml_tensor * kq_b,
|
|
1523
2031
|
ggml_tensor * sinks,
|
|
1524
|
-
ggml_tensor * v_mla,
|
|
2032
|
+
ggml_tensor * v_mla, // TODO: remove
|
|
1525
2033
|
float kq_scale,
|
|
1526
2034
|
int il) const {
|
|
2035
|
+
GGML_ASSERT(v_mla == nullptr);
|
|
2036
|
+
|
|
1527
2037
|
// these nodes are added to the graph together so that they are not reordered
|
|
1528
2038
|
// by doing so, the number of splits in the graph is reduced
|
|
2039
|
+
// expand k later to enable rope fusion which directly writes into k-v cache
|
|
1529
2040
|
ggml_build_forward_expand(gf, q_cur);
|
|
1530
|
-
ggml_build_forward_expand(gf, k_cur);
|
|
1531
2041
|
ggml_build_forward_expand(gf, v_cur);
|
|
2042
|
+
ggml_build_forward_expand(gf, k_cur);
|
|
1532
2043
|
|
|
1533
2044
|
const auto * mctx_cur = inp->mctx;
|
|
1534
2045
|
|
|
@@ -1550,6 +2061,89 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1550
2061
|
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
|
1551
2062
|
cb(cur, "kqv_out", il);
|
|
1552
2063
|
|
|
2064
|
+
if (wo) {
|
|
2065
|
+
cur = build_lora_mm(wo, cur);
|
|
2066
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
|
2067
|
+
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
|
|
2068
|
+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
2069
|
+
}
|
|
2070
|
+
}
|
|
2071
|
+
|
|
2072
|
+
if (wo_b) {
|
|
2073
|
+
cur = ggml_add(ctx0, cur, wo_b);
|
|
2074
|
+
}
|
|
2075
|
+
|
|
2076
|
+
return cur;
|
|
2077
|
+
}
|
|
2078
|
+
|
|
2079
|
+
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|
2080
|
+
ggml_context * ctx0,
|
|
2081
|
+
const llama_ubatch & ubatch,
|
|
2082
|
+
const llama_hparams & hparams,
|
|
2083
|
+
const llama_cparams & cparams,
|
|
2084
|
+
const llama_kv_cache_context * mctx_cur) {
|
|
2085
|
+
|
|
2086
|
+
auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
|
|
2087
|
+
|
|
2088
|
+
{
|
|
2089
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
|
2090
|
+
|
|
2091
|
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
2092
|
+
|
|
2093
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
|
2094
|
+
ggml_set_input(inp->self_kq_mask);
|
|
2095
|
+
|
|
2096
|
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
2097
|
+
}
|
|
2098
|
+
|
|
2099
|
+
return inp;
|
|
2100
|
+
}
|
|
2101
|
+
|
|
2102
|
+
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
|
|
2103
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
|
2104
|
+
|
|
2105
|
+
auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
|
2106
|
+
|
|
2107
|
+
return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
|
|
2108
|
+
}
|
|
2109
|
+
|
|
2110
|
+
ggml_tensor * llm_graph_context::build_attn(
|
|
2111
|
+
llm_graph_input_attn_k * inp,
|
|
2112
|
+
ggml_tensor * wo,
|
|
2113
|
+
ggml_tensor * wo_b,
|
|
2114
|
+
ggml_tensor * q_cur,
|
|
2115
|
+
ggml_tensor * k_cur,
|
|
2116
|
+
ggml_tensor * v_cur,
|
|
2117
|
+
ggml_tensor * kq_b,
|
|
2118
|
+
ggml_tensor * sinks,
|
|
2119
|
+
ggml_tensor * v_mla,
|
|
2120
|
+
float kq_scale,
|
|
2121
|
+
int il) const {
|
|
2122
|
+
// these nodes are added to the graph together so that they are not reordered
|
|
2123
|
+
// by doing so, the number of splits in the graph is reduced
|
|
2124
|
+
// expand k later to enable rope fusion which directly writes into k-v cache
|
|
2125
|
+
ggml_build_forward_expand(gf, q_cur);
|
|
2126
|
+
ggml_build_forward_expand(gf, v_cur);
|
|
2127
|
+
ggml_build_forward_expand(gf, k_cur);
|
|
2128
|
+
|
|
2129
|
+
const auto * mctx_cur = inp->mctx;
|
|
2130
|
+
|
|
2131
|
+
// store to KV cache
|
|
2132
|
+
{
|
|
2133
|
+
const auto & k_idxs = inp->get_k_idxs();
|
|
2134
|
+
|
|
2135
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
|
2136
|
+
}
|
|
2137
|
+
|
|
2138
|
+
const auto & kq_mask = inp->get_kq_mask();
|
|
2139
|
+
|
|
2140
|
+
ggml_tensor * q = q_cur;
|
|
2141
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
2142
|
+
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
|
|
2143
|
+
|
|
2144
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
|
2145
|
+
cb(cur, "kqv_out", il);
|
|
2146
|
+
|
|
1553
2147
|
if (wo) {
|
|
1554
2148
|
cur = build_lora_mm(wo, cur);
|
|
1555
2149
|
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
@@ -1637,7 +2231,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
|
1637
2231
|
|
|
1638
2232
|
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
|
1639
2233
|
|
|
1640
|
-
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc,
|
|
2234
|
+
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
|
|
1641
2235
|
ggml_set_input(inp->cross_kq_mask);
|
|
1642
2236
|
|
|
1643
2237
|
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
|
@@ -1695,32 +2289,30 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|
|
1695
2289
|
|
|
1696
2290
|
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
|
1697
2291
|
|
|
1698
|
-
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1699
|
-
|
|
1700
2292
|
{
|
|
1701
|
-
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
|
1702
|
-
|
|
1703
2293
|
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
1704
2294
|
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
1705
2295
|
|
|
1706
|
-
inp->self_kq_mask =
|
|
2296
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
|
1707
2297
|
ggml_set_input(inp->self_kq_mask);
|
|
2298
|
+
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
|
1708
2299
|
|
|
1709
2300
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
2301
|
+
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
|
1710
2302
|
}
|
|
1711
2303
|
|
|
1712
2304
|
{
|
|
1713
2305
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
|
1714
2306
|
|
|
1715
|
-
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
|
1716
|
-
|
|
1717
2307
|
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
1718
2308
|
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
1719
2309
|
|
|
1720
|
-
inp->self_kq_mask_swa =
|
|
2310
|
+
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
|
1721
2311
|
ggml_set_input(inp->self_kq_mask_swa);
|
|
2312
|
+
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
|
1722
2313
|
|
|
1723
2314
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
2315
|
+
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
|
1724
2316
|
}
|
|
1725
2317
|
|
|
1726
2318
|
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
|
@@ -1777,6 +2369,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
|
|
1777
2369
|
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
|
1778
2370
|
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
|
1779
2371
|
|
|
2372
|
+
inp->head = mctx_cur->get_head();
|
|
2373
|
+
inp->rs_z = mctx_cur->get_rs_z();
|
|
2374
|
+
|
|
1780
2375
|
return inp;
|
|
1781
2376
|
}
|
|
1782
2377
|
|
|
@@ -1845,19 +2440,91 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1845
2440
|
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
1846
2441
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
1847
2442
|
|
|
1848
|
-
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
|
2443
|
+
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
|
|
1849
2444
|
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
1850
2445
|
|
|
1851
|
-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
2446
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
1852
2447
|
|
|
1853
2448
|
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
1854
2449
|
}
|
|
1855
2450
|
|
|
2451
|
+
llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
|
|
2452
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
2453
|
+
|
|
2454
|
+
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
|
|
2455
|
+
auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
2456
|
+
|
|
2457
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
2458
|
+
|
|
2459
|
+
return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
|
|
2460
|
+
}
|
|
2461
|
+
|
|
2462
|
+
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
|
|
2463
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
|
|
2464
|
+
|
|
2465
|
+
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
|
2466
|
+
|
|
2467
|
+
// build iswa attention input
|
|
2468
|
+
const auto * attn_ctx = mctx_cur->get_attn();
|
|
2469
|
+
|
|
2470
|
+
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
|
2471
|
+
|
|
2472
|
+
{
|
|
2473
|
+
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
2474
|
+
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
2475
|
+
|
|
2476
|
+
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
|
2477
|
+
ggml_set_input(inp_attn->self_kq_mask);
|
|
2478
|
+
|
|
2479
|
+
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
|
2480
|
+
}
|
|
2481
|
+
|
|
2482
|
+
{
|
|
2483
|
+
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
2484
|
+
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
2485
|
+
|
|
2486
|
+
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
|
2487
|
+
ggml_set_input(inp_attn->self_kq_mask_swa);
|
|
2488
|
+
|
|
2489
|
+
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
|
2490
|
+
}
|
|
2491
|
+
|
|
2492
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
2493
|
+
|
|
2494
|
+
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
|
|
2495
|
+
}
|
|
2496
|
+
|
|
2497
|
+
void llm_graph_context::build_dense_out(
|
|
2498
|
+
ggml_tensor * dense_2,
|
|
2499
|
+
ggml_tensor * dense_2_b,
|
|
2500
|
+
ggml_tensor * dense_3) const {
|
|
2501
|
+
if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
|
|
2502
|
+
return;
|
|
2503
|
+
}
|
|
2504
|
+
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
|
|
2505
|
+
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
|
|
2506
|
+
|
|
2507
|
+
if (dense_2) {
|
|
2508
|
+
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
|
2509
|
+
}
|
|
2510
|
+
if (dense_2_b) {
|
|
2511
|
+
cur = ggml_add(ctx0, cur, dense_2_b);
|
|
2512
|
+
}
|
|
2513
|
+
if (dense_3) {
|
|
2514
|
+
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
|
2515
|
+
}
|
|
2516
|
+
cb(cur, "result_embd_pooled", -1);
|
|
2517
|
+
res->t_embd_pooled = cur;
|
|
2518
|
+
ggml_build_forward_expand(gf, cur);
|
|
2519
|
+
}
|
|
2520
|
+
|
|
2521
|
+
|
|
1856
2522
|
void llm_graph_context::build_pooling(
|
|
1857
2523
|
ggml_tensor * cls,
|
|
1858
2524
|
ggml_tensor * cls_b,
|
|
1859
2525
|
ggml_tensor * cls_out,
|
|
1860
|
-
ggml_tensor * cls_out_b
|
|
2526
|
+
ggml_tensor * cls_out_b,
|
|
2527
|
+
ggml_tensor * cls_norm) const {
|
|
1861
2528
|
if (!cparams.embeddings) {
|
|
1862
2529
|
return;
|
|
1863
2530
|
}
|
|
@@ -1896,8 +2563,15 @@ void llm_graph_context::build_pooling(
|
|
|
1896
2563
|
} break;
|
|
1897
2564
|
case LLAMA_POOLING_TYPE_RANK:
|
|
1898
2565
|
{
|
|
1899
|
-
|
|
1900
|
-
|
|
2566
|
+
if (arch == LLM_ARCH_MODERN_BERT) {
|
|
2567
|
+
// modern bert gte reranker builds mean first then applies prediction head and classifier
|
|
2568
|
+
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
|
|
2569
|
+
ggml_tensor * inp_mean = build_inp_mean();
|
|
2570
|
+
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
|
|
2571
|
+
} else {
|
|
2572
|
+
ggml_tensor * inp_cls = build_inp_cls();
|
|
2573
|
+
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
|
2574
|
+
}
|
|
1901
2575
|
|
|
1902
2576
|
// classification head
|
|
1903
2577
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
@@ -1906,7 +2580,15 @@ void llm_graph_context::build_pooling(
|
|
|
1906
2580
|
if (cls_b) {
|
|
1907
2581
|
cur = ggml_add(ctx0, cur, cls_b);
|
|
1908
2582
|
}
|
|
1909
|
-
|
|
2583
|
+
if (arch == LLM_ARCH_MODERN_BERT) {
|
|
2584
|
+
cur = ggml_gelu(ctx0, cur);
|
|
2585
|
+
} else {
|
|
2586
|
+
cur = ggml_tanh(ctx0, cur);
|
|
2587
|
+
}
|
|
2588
|
+
if (cls_norm) {
|
|
2589
|
+
// head norm
|
|
2590
|
+
cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
|
|
2591
|
+
}
|
|
1910
2592
|
}
|
|
1911
2593
|
|
|
1912
2594
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
@@ -1921,7 +2603,7 @@ void llm_graph_context::build_pooling(
|
|
|
1921
2603
|
}
|
|
1922
2604
|
|
|
1923
2605
|
// softmax for qwen3 reranker
|
|
1924
|
-
if (arch == LLM_ARCH_QWEN3) {
|
|
2606
|
+
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
|
1925
2607
|
cur = ggml_soft_max(ctx0, cur);
|
|
1926
2608
|
}
|
|
1927
2609
|
} break;
|
|
@@ -1937,6 +2619,94 @@ void llm_graph_context::build_pooling(
|
|
|
1937
2619
|
ggml_build_forward_expand(gf, cur);
|
|
1938
2620
|
}
|
|
1939
2621
|
|
|
2622
|
+
void llm_graph_context::build_sampling() const {
|
|
2623
|
+
if (samplers.empty() || !res->t_logits) {
|
|
2624
|
+
return;
|
|
2625
|
+
}
|
|
2626
|
+
|
|
2627
|
+
std::array<ggml_tensor *, 2> outs;
|
|
2628
|
+
outs[0] = res->t_logits;
|
|
2629
|
+
|
|
2630
|
+
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
|
2631
|
+
res->add_input(std::move(inp_sampling));
|
|
2632
|
+
|
|
2633
|
+
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
|
2634
|
+
int32_t logit_row_idx = 0;
|
|
2635
|
+
|
|
2636
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
|
2637
|
+
if (ubatch.output[i]) {
|
|
2638
|
+
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
2639
|
+
seq_to_logit_row[seq_id] = logit_row_idx;
|
|
2640
|
+
logit_row_idx++;
|
|
2641
|
+
}
|
|
2642
|
+
}
|
|
2643
|
+
|
|
2644
|
+
// res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
|
|
2645
|
+
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
|
|
2646
|
+
|
|
2647
|
+
// add a dummy row of logits
|
|
2648
|
+
// this trick makes the graph static, regardless of which samplers are activated
|
|
2649
|
+
// this is important in order to minimize graph reallocations
|
|
2650
|
+
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
|
2651
|
+
|
|
2652
|
+
for (const auto & [seq_id, sampler] : samplers) {
|
|
2653
|
+
const auto it = seq_to_logit_row.find(seq_id);
|
|
2654
|
+
|
|
2655
|
+
// inactive samplers always work on the first row
|
|
2656
|
+
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
|
|
2657
|
+
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
|
|
2658
|
+
|
|
2659
|
+
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
|
2660
|
+
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
|
2661
|
+
|
|
2662
|
+
struct llama_sampler_data data = {
|
|
2663
|
+
/*.logits =*/ logits_seq,
|
|
2664
|
+
/*.probs =*/ nullptr,
|
|
2665
|
+
/*.sampled =*/ nullptr,
|
|
2666
|
+
/*.candidates =*/ nullptr,
|
|
2667
|
+
};
|
|
2668
|
+
|
|
2669
|
+
assert(sampler->iface->backend_apply);
|
|
2670
|
+
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
|
2671
|
+
|
|
2672
|
+
if (data.sampled != nullptr) {
|
|
2673
|
+
res->t_sampled[seq_id] = data.sampled;
|
|
2674
|
+
outs[1] = data.sampled;
|
|
2675
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2676
|
+
}
|
|
2677
|
+
|
|
2678
|
+
if (data.probs != nullptr) {
|
|
2679
|
+
res->t_sampled_probs[seq_id] = data.probs;
|
|
2680
|
+
outs[1] = data.probs;
|
|
2681
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2682
|
+
}
|
|
2683
|
+
|
|
2684
|
+
if (data.logits != nullptr) {
|
|
2685
|
+
res->t_sampled_logits[seq_id] = data.logits;
|
|
2686
|
+
outs[1] = data.logits;
|
|
2687
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2688
|
+
}
|
|
2689
|
+
|
|
2690
|
+
if (data.candidates != nullptr) {
|
|
2691
|
+
res->t_candidates[seq_id] = data.candidates;
|
|
2692
|
+
outs[1] = data.candidates;
|
|
2693
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2694
|
+
}
|
|
2695
|
+
}
|
|
2696
|
+
|
|
2697
|
+
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
|
2698
|
+
/*
|
|
2699
|
+
for (const auto & [seq_id, sampler] : samplers) {
|
|
2700
|
+
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
|
|
2701
|
+
ggml_tensor * selected_token = it->second;
|
|
2702
|
+
if (selected_token != nullptr) {
|
|
2703
|
+
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
|
2704
|
+
}
|
|
2705
|
+
}
|
|
2706
|
+
}
|
|
2707
|
+
*/
|
|
2708
|
+
}
|
|
2709
|
+
|
|
1940
2710
|
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
|
1941
2711
|
// TODO move to hparams if a T5 variant appears that uses a different value
|
|
1942
2712
|
const int64_t max_distance = 128;
|
|
@@ -1952,7 +2722,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
|
|
|
1952
2722
|
|
|
1953
2723
|
if (bidirectional) {
|
|
1954
2724
|
relative_bucket += (relative_position > 0) * n_buckets;
|
|
1955
|
-
relative_position = abs(relative_position);
|
|
2725
|
+
relative_position = std::abs(relative_position);
|
|
1956
2726
|
} else {
|
|
1957
2727
|
relative_position = -std::min<int32_t>(relative_position, 0);
|
|
1958
2728
|
}
|