whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -7,13 +7,51 @@
|
|
|
7
7
|
#include "llama-kv-cache.h"
|
|
8
8
|
#include "llama-kv-cache-iswa.h"
|
|
9
9
|
#include "llama-memory-hybrid.h"
|
|
10
|
+
#include "llama-memory-hybrid-iswa.h"
|
|
10
11
|
#include "llama-memory-recurrent.h"
|
|
11
12
|
|
|
12
13
|
#include <cassert>
|
|
13
14
|
#include <cmath>
|
|
14
15
|
#include <cstring>
|
|
16
|
+
#include <numeric>
|
|
17
|
+
#include <sstream>
|
|
15
18
|
#include <unordered_set>
|
|
16
19
|
|
|
20
|
+
// dedup helpers
|
|
21
|
+
|
|
22
|
+
static ggml_tensor * build_kq_mask(
|
|
23
|
+
ggml_context * ctx,
|
|
24
|
+
const llama_kv_cache_context * mctx,
|
|
25
|
+
const llama_ubatch & ubatch,
|
|
26
|
+
const llama_cparams & cparams) {
|
|
27
|
+
const auto n_kv = mctx->get_n_kv();
|
|
28
|
+
const auto n_tokens = ubatch.n_tokens;
|
|
29
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
30
|
+
|
|
31
|
+
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
static bool can_reuse_kq_mask(
|
|
35
|
+
ggml_tensor * kq_mask,
|
|
36
|
+
const llama_kv_cache_context * mctx,
|
|
37
|
+
const llama_ubatch & ubatch,
|
|
38
|
+
const llama_cparams & cparams) {
|
|
39
|
+
const auto n_kv = mctx->get_n_kv();
|
|
40
|
+
const auto n_tokens = ubatch.n_tokens;
|
|
41
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
42
|
+
|
|
43
|
+
bool res = true;
|
|
44
|
+
|
|
45
|
+
res &= (kq_mask->ne[0] == n_kv);
|
|
46
|
+
res &= (kq_mask->ne[1] == n_tokens/n_stream);
|
|
47
|
+
res &= (kq_mask->ne[2] == 1);
|
|
48
|
+
res &= (kq_mask->ne[3] == n_stream);
|
|
49
|
+
|
|
50
|
+
return res;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// impl
|
|
54
|
+
|
|
17
55
|
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
18
56
|
if (ubatch->token) {
|
|
19
57
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
@@ -22,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
22
60
|
}
|
|
23
61
|
|
|
24
62
|
if (ubatch->embd) {
|
|
25
|
-
|
|
63
|
+
GGML_ASSERT(n_embd == embd->ne[0]);
|
|
64
|
+
|
|
26
65
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
27
66
|
|
|
28
67
|
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
|
@@ -32,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
32
71
|
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|
33
72
|
bool res = true;
|
|
34
73
|
|
|
35
|
-
res &= (!
|
|
36
|
-
res &= (!
|
|
74
|
+
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
|
75
|
+
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
|
37
76
|
|
|
38
77
|
return res;
|
|
39
78
|
}
|
|
@@ -96,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
96
135
|
|
|
97
136
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
98
137
|
|
|
99
|
-
for (int
|
|
100
|
-
for (int
|
|
101
|
-
|
|
102
|
-
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
103
|
-
}
|
|
138
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
139
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
140
|
+
data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
104
141
|
}
|
|
105
142
|
}
|
|
106
143
|
}
|
|
@@ -148,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
|
|
|
148
185
|
}
|
|
149
186
|
|
|
150
187
|
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
151
|
-
if (cparams.embeddings
|
|
188
|
+
if (cparams.embeddings &&
|
|
189
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
|
|
190
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
|
|
191
|
+
|
|
152
192
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
153
193
|
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
154
194
|
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
@@ -210,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
210
250
|
|
|
211
251
|
const bool last = (
|
|
212
252
|
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
|
213
|
-
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
|
253
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
|
214
254
|
);
|
|
215
255
|
|
|
216
256
|
for (int i = 0; i < n_tokens; ++i) {
|
|
@@ -323,34 +363,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
323
363
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
324
364
|
|
|
325
365
|
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
|
326
|
-
for (int
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
const llama_pos p1 = ubatch->pos[i1];
|
|
366
|
+
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
367
|
+
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
368
|
+
const llama_pos p1 = ubatch->pos[i1];
|
|
330
369
|
|
|
331
|
-
|
|
370
|
+
const uint64_t idst = i1*n_kv;
|
|
332
371
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
372
|
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
373
|
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
374
|
+
const llama_pos p0 = ubatch->pos[i0];
|
|
336
375
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
// mask future tokens
|
|
343
|
-
if (cparams.causal_attn && p0 > p1) {
|
|
344
|
-
continue;
|
|
345
|
-
}
|
|
376
|
+
// mask different sequences
|
|
377
|
+
if (s0 != s1) {
|
|
378
|
+
continue;
|
|
379
|
+
}
|
|
346
380
|
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
381
|
+
// mask future tokens
|
|
382
|
+
if (cparams.causal_attn && p0 > p1) {
|
|
383
|
+
continue;
|
|
384
|
+
}
|
|
351
385
|
|
|
352
|
-
|
|
386
|
+
// apply SWA if any
|
|
387
|
+
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
388
|
+
continue;
|
|
353
389
|
}
|
|
390
|
+
|
|
391
|
+
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
354
392
|
}
|
|
355
393
|
}
|
|
356
394
|
};
|
|
@@ -403,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
|
|
403
441
|
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
404
442
|
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
405
443
|
|
|
406
|
-
res &= self_kq_mask
|
|
407
|
-
|
|
444
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
|
445
|
+
|
|
446
|
+
return res;
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
|
|
450
|
+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
|
451
|
+
|
|
452
|
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
|
|
456
|
+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
|
457
|
+
|
|
458
|
+
this->mctx = mctx;
|
|
459
|
+
|
|
460
|
+
bool res = true;
|
|
461
|
+
|
|
462
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
463
|
+
|
|
464
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
|
408
465
|
|
|
409
466
|
return res;
|
|
410
467
|
}
|
|
@@ -434,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
|
|
434
491
|
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
|
435
492
|
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
436
493
|
|
|
437
|
-
res &= self_kq_mask
|
|
438
|
-
res &=
|
|
439
|
-
|
|
440
|
-
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
|
441
|
-
res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
|
494
|
+
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
|
495
|
+
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
|
442
496
|
|
|
443
497
|
return res;
|
|
444
498
|
}
|
|
@@ -454,27 +508,20 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
454
508
|
|
|
455
509
|
float * data = (float *) cross_kq_mask->data;
|
|
456
510
|
|
|
457
|
-
for (int
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
511
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
512
|
+
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
|
513
|
+
for (int j = 0; j < n_enc; ++j) {
|
|
514
|
+
float f = -INFINITY;
|
|
461
515
|
|
|
462
|
-
|
|
463
|
-
|
|
516
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
517
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
464
518
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
}
|
|
519
|
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
520
|
+
f = 0.0f;
|
|
468
521
|
}
|
|
469
|
-
|
|
470
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
|
471
522
|
}
|
|
472
|
-
}
|
|
473
523
|
|
|
474
|
-
|
|
475
|
-
for (int j = 0; j < n_enc; ++j) {
|
|
476
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
477
|
-
}
|
|
524
|
+
data[i*n_enc + j] = f;
|
|
478
525
|
}
|
|
479
526
|
}
|
|
480
527
|
}
|
|
@@ -508,8 +555,118 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
|
|
508
555
|
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
509
556
|
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
510
557
|
|
|
511
|
-
res &= inp_attn->self_kq_mask
|
|
512
|
-
|
|
558
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
|
559
|
+
|
|
560
|
+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
561
|
+
|
|
562
|
+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
563
|
+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
564
|
+
|
|
565
|
+
res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
566
|
+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
567
|
+
|
|
568
|
+
return res;
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
// TODO: Hybrid input classes are a bit redundant.
|
|
572
|
+
// Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
|
|
573
|
+
// Refactoring is required in the future.
|
|
574
|
+
void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
|
|
575
|
+
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
576
|
+
|
|
577
|
+
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
578
|
+
|
|
579
|
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
580
|
+
|
|
581
|
+
if (inp_rs->s_copy) {
|
|
582
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
583
|
+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
584
|
+
|
|
585
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
586
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
587
|
+
data[i] = mctx->get_recr()->s_copy(i);
|
|
588
|
+
}
|
|
589
|
+
}
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
|
|
593
|
+
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
|
594
|
+
|
|
595
|
+
this->mctx = mctx;
|
|
596
|
+
|
|
597
|
+
bool res = true;
|
|
598
|
+
|
|
599
|
+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
600
|
+
|
|
601
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
|
602
|
+
|
|
603
|
+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
604
|
+
|
|
605
|
+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
606
|
+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
607
|
+
|
|
608
|
+
res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
609
|
+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
610
|
+
|
|
611
|
+
return res;
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|
615
|
+
const auto * attn_ctx = mctx->get_attn();
|
|
616
|
+
|
|
617
|
+
// base tensors may not be allocated if there are no non-SWA attention layers
|
|
618
|
+
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
|
619
|
+
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
620
|
+
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
|
621
|
+
|
|
622
|
+
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
// swa tensors may not be allocated if there are no SWA attention layers
|
|
626
|
+
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
|
627
|
+
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
|
628
|
+
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
|
629
|
+
|
|
630
|
+
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
634
|
+
|
|
635
|
+
if (inp_rs->s_copy) {
|
|
636
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
637
|
+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
638
|
+
|
|
639
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
640
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
641
|
+
data[i] = mctx->get_recr()->s_copy(i);
|
|
642
|
+
}
|
|
643
|
+
}
|
|
644
|
+
}
|
|
645
|
+
|
|
646
|
+
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
|
|
647
|
+
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
|
|
648
|
+
|
|
649
|
+
this->mctx = mctx;
|
|
650
|
+
|
|
651
|
+
bool res = true;
|
|
652
|
+
|
|
653
|
+
const auto * attn_ctx = mctx->get_attn();
|
|
654
|
+
|
|
655
|
+
// base tensors may not be allocated if there are no non-SWA attention layers
|
|
656
|
+
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
|
657
|
+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
658
|
+
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
659
|
+
|
|
660
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
|
661
|
+
}
|
|
662
|
+
|
|
663
|
+
// swa tensors may not be allocated if there are no SWA attention layers
|
|
664
|
+
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
|
665
|
+
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
|
666
|
+
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
667
|
+
|
|
668
|
+
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
|
669
|
+
}
|
|
513
670
|
|
|
514
671
|
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
515
672
|
|
|
@@ -575,7 +732,8 @@ int64_t llm_graph_result::get_max_nodes() const {
|
|
|
575
732
|
}
|
|
576
733
|
|
|
577
734
|
void llm_graph_result::reset() {
|
|
578
|
-
|
|
735
|
+
t_inp_tokens = nullptr;
|
|
736
|
+
t_inp_embd = nullptr;
|
|
579
737
|
t_logits = nullptr;
|
|
580
738
|
t_embd = nullptr;
|
|
581
739
|
t_embd_pooled = nullptr;
|
|
@@ -691,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
691
849
|
ubatch (params.ubatch),
|
|
692
850
|
n_embd (hparams.n_embd),
|
|
693
851
|
n_layer (hparams.n_layer),
|
|
694
|
-
n_rot (hparams.n_rot),
|
|
852
|
+
n_rot (hparams.n_rot()),
|
|
695
853
|
n_ctx (cparams.n_ctx),
|
|
696
854
|
n_head (hparams.n_head()),
|
|
697
855
|
n_head_kv (hparams.n_head_kv()),
|
|
698
|
-
n_embd_head_k (hparams.n_embd_head_k),
|
|
856
|
+
n_embd_head_k (hparams.n_embd_head_k()),
|
|
699
857
|
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
|
700
|
-
n_embd_head_v (hparams.n_embd_head_v),
|
|
858
|
+
n_embd_head_v (hparams.n_embd_head_v()),
|
|
701
859
|
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
|
702
860
|
n_expert (hparams.n_expert),
|
|
703
861
|
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
|
@@ -742,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec(
|
|
|
742
900
|
|
|
743
901
|
ggml_tensor * llm_graph_context::build_lora_mm(
|
|
744
902
|
ggml_tensor * w,
|
|
745
|
-
ggml_tensor * cur
|
|
903
|
+
ggml_tensor * cur,
|
|
904
|
+
ggml_tensor * w_s) const {
|
|
746
905
|
ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
|
747
906
|
|
|
748
907
|
for (const auto & lora : *loras) {
|
|
@@ -763,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm(
|
|
|
763
922
|
res = ggml_add(ctx0, res, ab_cur);
|
|
764
923
|
}
|
|
765
924
|
|
|
925
|
+
if (w_s) {
|
|
926
|
+
res = ggml_mul(ctx0, res, w_s);
|
|
927
|
+
}
|
|
928
|
+
|
|
766
929
|
return res;
|
|
767
930
|
}
|
|
768
931
|
|
|
@@ -888,6 +1051,26 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
888
1051
|
switch (type_op) {
|
|
889
1052
|
case LLM_FFN_SILU:
|
|
890
1053
|
if (gate && type_gate == LLM_FFN_PAR) {
|
|
1054
|
+
// Step35: HF clamps gate (after SiLU) and up before multiplication
|
|
1055
|
+
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
|
1056
|
+
const float limit = hparams.swiglu_clamp_shexp[il];
|
|
1057
|
+
constexpr float eps = 1e-6f;
|
|
1058
|
+
if (limit > eps) {
|
|
1059
|
+
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
|
1060
|
+
cb(gate_act, "ffn_silu", il);
|
|
1061
|
+
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
|
1062
|
+
cb(gate_act, "ffn_silu_clamped", il);
|
|
1063
|
+
|
|
1064
|
+
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
|
|
1065
|
+
cb(tmp, "ffn_up_clamped", il);
|
|
1066
|
+
|
|
1067
|
+
cur = ggml_mul(ctx0, gate_act, tmp);
|
|
1068
|
+
cb(cur, "ffn_swiglu_limited", il);
|
|
1069
|
+
type_gate = LLM_FFN_SEQ;
|
|
1070
|
+
break;
|
|
1071
|
+
}
|
|
1072
|
+
}
|
|
1073
|
+
|
|
891
1074
|
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
|
892
1075
|
cb(cur, "ffn_swiglu", il);
|
|
893
1076
|
type_gate = LLM_FFN_SEQ;
|
|
@@ -951,8 +1134,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
951
1134
|
|
|
952
1135
|
if (down) {
|
|
953
1136
|
cur = build_lora_mm(down, cur);
|
|
954
|
-
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
955
|
-
// GLM4 and
|
|
1137
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
|
1138
|
+
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
|
|
956
1139
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
957
1140
|
}
|
|
958
1141
|
}
|
|
@@ -984,11 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
984
1167
|
int64_t n_expert_used,
|
|
985
1168
|
llm_ffn_op_type type_op,
|
|
986
1169
|
bool norm_w,
|
|
987
|
-
bool scale_w,
|
|
988
1170
|
float w_scale,
|
|
989
1171
|
llama_expert_gating_func_type gating_op,
|
|
990
1172
|
int il,
|
|
991
|
-
ggml_tensor * probs_in
|
|
1173
|
+
ggml_tensor * probs_in,
|
|
1174
|
+
ggml_tensor * gate_up_exps,
|
|
1175
|
+
ggml_tensor * up_exps_s,
|
|
1176
|
+
ggml_tensor * gate_exps_s,
|
|
1177
|
+
ggml_tensor * down_exps_s) const {
|
|
992
1178
|
return build_moe_ffn(
|
|
993
1179
|
cur,
|
|
994
1180
|
gate_inp, /* gate_inp_b */ nullptr,
|
|
@@ -1000,11 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1000
1186
|
n_expert_used,
|
|
1001
1187
|
type_op,
|
|
1002
1188
|
norm_w,
|
|
1003
|
-
scale_w,
|
|
1004
1189
|
w_scale,
|
|
1005
1190
|
gating_op,
|
|
1006
1191
|
il,
|
|
1007
|
-
probs_in
|
|
1192
|
+
probs_in,
|
|
1193
|
+
gate_up_exps,
|
|
1194
|
+
/* gate_up_exps_b */ nullptr,
|
|
1195
|
+
up_exps_s,
|
|
1196
|
+
gate_exps_s,
|
|
1197
|
+
down_exps_s
|
|
1008
1198
|
);
|
|
1009
1199
|
}
|
|
1010
1200
|
|
|
@@ -1023,11 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1023
1213
|
int64_t n_expert_used,
|
|
1024
1214
|
llm_ffn_op_type type_op,
|
|
1025
1215
|
bool norm_w,
|
|
1026
|
-
bool scale_w,
|
|
1027
1216
|
float w_scale,
|
|
1028
1217
|
llama_expert_gating_func_type gating_op,
|
|
1029
1218
|
int il,
|
|
1030
|
-
ggml_tensor * probs_in
|
|
1219
|
+
ggml_tensor * probs_in,
|
|
1220
|
+
ggml_tensor * gate_up_exps,
|
|
1221
|
+
ggml_tensor * gate_up_exps_b,
|
|
1222
|
+
ggml_tensor * up_exps_s,
|
|
1223
|
+
ggml_tensor * gate_exps_s,
|
|
1224
|
+
ggml_tensor * down_exps_s) const {
|
|
1031
1225
|
const int64_t n_embd = cur->ne[0];
|
|
1032
1226
|
const int64_t n_tokens = cur->ne[1];
|
|
1033
1227
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
|
@@ -1149,7 +1343,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1149
1343
|
|
|
1150
1344
|
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
|
1151
1345
|
}
|
|
1152
|
-
if (
|
|
1346
|
+
if (w_scale != 0.0f && w_scale != 1.0f) {
|
|
1153
1347
|
weights = ggml_scale(ctx0, weights, w_scale);
|
|
1154
1348
|
cb(weights, "ffn_moe_weights_scaled", il);
|
|
1155
1349
|
}
|
|
@@ -1166,30 +1360,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1166
1360
|
cb(cur, "ffn_moe_weighted", il);
|
|
1167
1361
|
}
|
|
1168
1362
|
|
|
1169
|
-
ggml_tensor * up =
|
|
1170
|
-
|
|
1363
|
+
ggml_tensor * up = nullptr;
|
|
1364
|
+
ggml_tensor * experts = nullptr;
|
|
1171
1365
|
|
|
1172
|
-
if (
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1366
|
+
if (gate_up_exps) {
|
|
1367
|
+
// merged gate_up path: one mul_mat_id, then split into gate and up views
|
|
1368
|
+
ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
|
|
1369
|
+
cb(gate_up, "ffn_moe_gate_up", il);
|
|
1176
1370
|
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1371
|
+
if (gate_up_exps_b) {
|
|
1372
|
+
gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
|
|
1373
|
+
cb(gate_up, "ffn_moe_gate_up_biased", il);
|
|
1374
|
+
}
|
|
1375
|
+
|
|
1376
|
+
// apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
|
|
1377
|
+
if (up_exps_s) {
|
|
1378
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
|
|
1379
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1380
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1381
|
+
gate_up = ggml_mul(ctx0, gate_up, s);
|
|
1382
|
+
cb(gate_up, "ffn_moe_gate_up_scaled", il);
|
|
1383
|
+
}
|
|
1384
|
+
|
|
1385
|
+
const int64_t n_ff = gate_up->ne[0] / 2;
|
|
1386
|
+
cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
|
|
1180
1387
|
cb(cur, "ffn_moe_gate", il);
|
|
1388
|
+
up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
|
|
1389
|
+
cb(up, "ffn_moe_up", il);
|
|
1181
1390
|
} else {
|
|
1182
|
-
|
|
1183
|
-
|
|
1391
|
+
// separate gate and up path
|
|
1392
|
+
up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
1393
|
+
cb(up, "ffn_moe_up", il);
|
|
1394
|
+
|
|
1395
|
+
if (up_exps_b) {
|
|
1396
|
+
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
|
1397
|
+
cb(up, "ffn_moe_up_biased", il);
|
|
1398
|
+
}
|
|
1399
|
+
|
|
1400
|
+
// apply per-expert scale2 to up
|
|
1401
|
+
if (up_exps_s) {
|
|
1402
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
|
|
1403
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1404
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1405
|
+
up = ggml_mul(ctx0, up, s);
|
|
1406
|
+
cb(up, "ffn_moe_up_scaled", il);
|
|
1407
|
+
}
|
|
1408
|
+
|
|
1409
|
+
if (gate_exps) {
|
|
1410
|
+
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
1411
|
+
cb(cur, "ffn_moe_gate", il);
|
|
1412
|
+
} else {
|
|
1413
|
+
cur = up;
|
|
1414
|
+
}
|
|
1415
|
+
|
|
1416
|
+
if (gate_exps_b) {
|
|
1417
|
+
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
|
1418
|
+
cb(cur, "ffn_moe_gate_biased", il);
|
|
1419
|
+
}
|
|
1184
1420
|
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1421
|
+
// apply per-expert scale2 to gate
|
|
1422
|
+
if (gate_exps_s) {
|
|
1423
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
|
|
1424
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1425
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1426
|
+
cur = ggml_mul(ctx0, cur, s);
|
|
1427
|
+
cb(cur, "ffn_moe_gate_scaled", il);
|
|
1428
|
+
}
|
|
1188
1429
|
}
|
|
1189
1430
|
|
|
1431
|
+
const bool has_gate = gate_exps || gate_up_exps;
|
|
1432
|
+
|
|
1190
1433
|
switch (type_op) {
|
|
1191
1434
|
case LLM_FFN_SILU:
|
|
1192
1435
|
if (gate_exps) {
|
|
1436
|
+
// Step35: per-layer clamp for routed experts
|
|
1437
|
+
if (arch == LLM_ARCH_STEP35 && il >= 0) {
|
|
1438
|
+
const float limit = hparams.swiglu_clamp_exp[il];
|
|
1439
|
+
constexpr float eps = 1e-6f;
|
|
1440
|
+
if (limit > eps) {
|
|
1441
|
+
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
|
|
1442
|
+
cb(gate_act, "ffn_moe_silu", il);
|
|
1443
|
+
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
|
|
1444
|
+
cb(gate_act, "ffn_moe_silu_clamped", il);
|
|
1445
|
+
|
|
1446
|
+
up = ggml_clamp(ctx0, up, -limit, limit);
|
|
1447
|
+
cb(up, "ffn_moe_up_clamped", il);
|
|
1448
|
+
|
|
1449
|
+
cur = ggml_mul(ctx0, gate_act, up);
|
|
1450
|
+
cb(cur, "ffn_moe_swiglu_limited", il);
|
|
1451
|
+
break;
|
|
1452
|
+
}
|
|
1453
|
+
}
|
|
1454
|
+
}
|
|
1455
|
+
|
|
1456
|
+
if (has_gate) {
|
|
1193
1457
|
cur = ggml_swiglu_split(ctx0, cur, up);
|
|
1194
1458
|
cb(cur, "ffn_moe_swiglu", il);
|
|
1195
1459
|
} else {
|
|
@@ -1197,7 +1461,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1197
1461
|
cb(cur, "ffn_moe_silu", il);
|
|
1198
1462
|
} break;
|
|
1199
1463
|
case LLM_FFN_GELU:
|
|
1200
|
-
if (
|
|
1464
|
+
if (has_gate) {
|
|
1201
1465
|
cur = ggml_geglu_split(ctx0, cur, up);
|
|
1202
1466
|
cb(cur, "ffn_moe_geglu", il);
|
|
1203
1467
|
} else {
|
|
@@ -1213,7 +1477,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1213
1477
|
cb(cur, "ffn_moe_swiglu_oai", il);
|
|
1214
1478
|
} break;
|
|
1215
1479
|
case LLM_FFN_RELU:
|
|
1216
|
-
if (
|
|
1480
|
+
if (has_gate) {
|
|
1217
1481
|
cur = ggml_reglu_split(ctx0, cur, up);
|
|
1218
1482
|
cb(cur, "ffn_moe_reglu", il);
|
|
1219
1483
|
} else {
|
|
@@ -1221,7 +1485,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1221
1485
|
cb(cur, "ffn_moe_relu", il);
|
|
1222
1486
|
} break;
|
|
1223
1487
|
case LLM_FFN_RELU_SQR:
|
|
1224
|
-
if (
|
|
1488
|
+
if (has_gate) {
|
|
1225
1489
|
// TODO: add support for gated squared relu
|
|
1226
1490
|
GGML_ABORT("fatal error: gated squared relu not implemented");
|
|
1227
1491
|
} else {
|
|
@@ -1241,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1241
1505
|
cb(experts, "ffn_moe_down_biased", il);
|
|
1242
1506
|
}
|
|
1243
1507
|
|
|
1508
|
+
// apply per-expert scale2 to down
|
|
1509
|
+
if (down_exps_s) {
|
|
1510
|
+
ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
|
|
1511
|
+
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
|
|
1512
|
+
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
|
|
1513
|
+
experts = ggml_mul(ctx0, experts, s);
|
|
1514
|
+
cb(experts, "ffn_moe_down_scaled", il);
|
|
1515
|
+
}
|
|
1516
|
+
|
|
1244
1517
|
if (!weight_before_ffn) {
|
|
1245
1518
|
experts = ggml_mul(ctx0, experts, weights);
|
|
1246
1519
|
cb(cur, "ffn_moe_weighted", il);
|
|
@@ -1279,17 +1552,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
1279
1552
|
|
|
1280
1553
|
// input embeddings with optional lora
|
|
1281
1554
|
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
1282
|
-
const int64_t
|
|
1555
|
+
const int64_t n_embd_inp = hparams.n_embd_inp();
|
|
1556
|
+
const int64_t n_embd = hparams.n_embd;
|
|
1283
1557
|
|
|
1284
|
-
|
|
1558
|
+
assert(n_embd_inp >= n_embd);
|
|
1285
1559
|
|
|
1286
|
-
|
|
1560
|
+
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
|
|
1287
1561
|
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1562
|
+
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
|
1563
|
+
cb(inp->tokens, "inp_tokens", -1);
|
|
1564
|
+
ggml_set_input(inp->tokens);
|
|
1565
|
+
res->t_inp_tokens = inp->tokens;
|
|
1566
|
+
|
|
1567
|
+
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
|
|
1568
|
+
cb(inp->embd, "inp_embd", -1);
|
|
1569
|
+
ggml_set_input(inp->embd);
|
|
1570
|
+
|
|
1571
|
+
// select one of the 2 inputs, based on the batch contents
|
|
1572
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
|
1573
|
+
std::array<ggml_tensor *, 2> inps;
|
|
1574
|
+
|
|
1575
|
+
// token embeddings path (ubatch.token != nullptr)
|
|
1576
|
+
{
|
|
1577
|
+
auto & cur = inps[0];
|
|
1293
1578
|
|
|
1294
1579
|
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
|
1295
1580
|
|
|
@@ -1310,19 +1595,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
|
1310
1595
|
|
|
1311
1596
|
cur = ggml_add(ctx0, cur, inpL_delta);
|
|
1312
1597
|
}
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1598
|
+
|
|
1599
|
+
if (n_embd_inp != n_embd) {
|
|
1600
|
+
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
|
|
1601
|
+
}
|
|
1602
|
+
}
|
|
1603
|
+
|
|
1604
|
+
// vector embeddings path (ubatch.embd != nullptr)
|
|
1605
|
+
{
|
|
1606
|
+
auto & cur = inps[1];
|
|
1316
1607
|
|
|
1317
1608
|
cur = inp->embd;
|
|
1318
1609
|
}
|
|
1319
1610
|
|
|
1611
|
+
assert(ggml_are_same_shape (inps[0], inps[1]));
|
|
1612
|
+
assert(ggml_are_same_stride(inps[0], inps[1]));
|
|
1613
|
+
|
|
1614
|
+
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
|
|
1615
|
+
|
|
1616
|
+
if (n_embd_inp != n_embd) {
|
|
1617
|
+
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
|
|
1618
|
+
}
|
|
1619
|
+
|
|
1620
|
+
res->t_inp_embd = cur;
|
|
1621
|
+
|
|
1320
1622
|
// For Granite architecture
|
|
1321
1623
|
if (hparams.f_embedding_scale != 0.0f) {
|
|
1322
1624
|
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
|
1323
1625
|
}
|
|
1324
1626
|
|
|
1325
|
-
cb(cur, "
|
|
1627
|
+
cb(cur, "embd", -1);
|
|
1326
1628
|
|
|
1327
1629
|
res->add_input(std::move(inp));
|
|
1328
1630
|
|
|
@@ -1354,6 +1656,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
|
1354
1656
|
// this need to be 1x1xN for broadcasting
|
|
1355
1657
|
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
|
|
1356
1658
|
ggml_set_input(cur);
|
|
1659
|
+
ggml_set_name(cur, "attn_scale");
|
|
1357
1660
|
|
|
1358
1661
|
res->add_input(std::move(inp));
|
|
1359
1662
|
|
|
@@ -1363,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
|
1363
1666
|
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
|
1364
1667
|
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
|
1365
1668
|
// but this would make the graph topology depend on the number of output tokens, which can interere with
|
|
1366
|
-
// features that require constant topology such as
|
|
1669
|
+
// features that require constant topology such as pipeline parallelism
|
|
1367
1670
|
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
|
1368
1671
|
//if (n_outputs < n_tokens) {
|
|
1369
1672
|
// return nullptr;
|
|
@@ -1421,7 +1724,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
|
|
1421
1724
|
//}
|
|
1422
1725
|
|
|
1423
1726
|
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
|
|
1424
|
-
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc
|
|
1727
|
+
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
|
1425
1728
|
|
|
1426
1729
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
|
1427
1730
|
ggml_set_input(cur);
|
|
@@ -1499,7 +1802,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1499
1802
|
|
|
1500
1803
|
ggml_tensor * cur;
|
|
1501
1804
|
|
|
1502
|
-
|
|
1805
|
+
const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
|
|
1806
|
+
if (use_flash_attn) {
|
|
1503
1807
|
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
|
|
1504
1808
|
|
|
1505
1809
|
if (v_trans) {
|
|
@@ -1525,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1525
1829
|
if (v_mla) {
|
|
1526
1830
|
#if 0
|
|
1527
1831
|
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
|
|
1528
|
-
// However, the code is optimized for dimensions 0 and 1 being large, so this is
|
|
1832
|
+
// However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
|
|
1529
1833
|
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
|
1530
1834
|
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
|
1531
1835
|
#else
|
|
@@ -1695,14 +1999,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|
|
1695
1999
|
{
|
|
1696
2000
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
|
1697
2001
|
|
|
1698
|
-
const auto n_kv = mctx_cur->get_n_kv();
|
|
1699
|
-
const auto n_tokens = ubatch.n_tokens;
|
|
1700
|
-
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1701
|
-
|
|
1702
2002
|
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
1703
2003
|
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
|
1704
2004
|
|
|
1705
|
-
inp->self_kq_mask =
|
|
2005
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
|
2006
|
+
|
|
1706
2007
|
ggml_set_input(inp->self_kq_mask);
|
|
1707
2008
|
|
|
1708
2009
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1728,9 +2029,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1728
2029
|
ggml_tensor * v_cur,
|
|
1729
2030
|
ggml_tensor * kq_b,
|
|
1730
2031
|
ggml_tensor * sinks,
|
|
1731
|
-
ggml_tensor * v_mla,
|
|
2032
|
+
ggml_tensor * v_mla, // TODO: remove
|
|
1732
2033
|
float kq_scale,
|
|
1733
2034
|
int il) const {
|
|
2035
|
+
GGML_ASSERT(v_mla == nullptr);
|
|
2036
|
+
|
|
1734
2037
|
// these nodes are added to the graph together so that they are not reordered
|
|
1735
2038
|
// by doing so, the number of splits in the graph is reduced
|
|
1736
2039
|
// expand k later to enable rope fusion which directly writes into k-v cache
|
|
@@ -1758,6 +2061,89 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1758
2061
|
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
|
1759
2062
|
cb(cur, "kqv_out", il);
|
|
1760
2063
|
|
|
2064
|
+
if (wo) {
|
|
2065
|
+
cur = build_lora_mm(wo, cur);
|
|
2066
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
|
2067
|
+
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
|
|
2068
|
+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
2069
|
+
}
|
|
2070
|
+
}
|
|
2071
|
+
|
|
2072
|
+
if (wo_b) {
|
|
2073
|
+
cur = ggml_add(ctx0, cur, wo_b);
|
|
2074
|
+
}
|
|
2075
|
+
|
|
2076
|
+
return cur;
|
|
2077
|
+
}
|
|
2078
|
+
|
|
2079
|
+
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|
2080
|
+
ggml_context * ctx0,
|
|
2081
|
+
const llama_ubatch & ubatch,
|
|
2082
|
+
const llama_hparams & hparams,
|
|
2083
|
+
const llama_cparams & cparams,
|
|
2084
|
+
const llama_kv_cache_context * mctx_cur) {
|
|
2085
|
+
|
|
2086
|
+
auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
|
|
2087
|
+
|
|
2088
|
+
{
|
|
2089
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
|
2090
|
+
|
|
2091
|
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
2092
|
+
|
|
2093
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
|
2094
|
+
ggml_set_input(inp->self_kq_mask);
|
|
2095
|
+
|
|
2096
|
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
2097
|
+
}
|
|
2098
|
+
|
|
2099
|
+
return inp;
|
|
2100
|
+
}
|
|
2101
|
+
|
|
2102
|
+
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
|
|
2103
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
|
2104
|
+
|
|
2105
|
+
auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
|
2106
|
+
|
|
2107
|
+
return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
|
|
2108
|
+
}
|
|
2109
|
+
|
|
2110
|
+
ggml_tensor * llm_graph_context::build_attn(
|
|
2111
|
+
llm_graph_input_attn_k * inp,
|
|
2112
|
+
ggml_tensor * wo,
|
|
2113
|
+
ggml_tensor * wo_b,
|
|
2114
|
+
ggml_tensor * q_cur,
|
|
2115
|
+
ggml_tensor * k_cur,
|
|
2116
|
+
ggml_tensor * v_cur,
|
|
2117
|
+
ggml_tensor * kq_b,
|
|
2118
|
+
ggml_tensor * sinks,
|
|
2119
|
+
ggml_tensor * v_mla,
|
|
2120
|
+
float kq_scale,
|
|
2121
|
+
int il) const {
|
|
2122
|
+
// these nodes are added to the graph together so that they are not reordered
|
|
2123
|
+
// by doing so, the number of splits in the graph is reduced
|
|
2124
|
+
// expand k later to enable rope fusion which directly writes into k-v cache
|
|
2125
|
+
ggml_build_forward_expand(gf, q_cur);
|
|
2126
|
+
ggml_build_forward_expand(gf, v_cur);
|
|
2127
|
+
ggml_build_forward_expand(gf, k_cur);
|
|
2128
|
+
|
|
2129
|
+
const auto * mctx_cur = inp->mctx;
|
|
2130
|
+
|
|
2131
|
+
// store to KV cache
|
|
2132
|
+
{
|
|
2133
|
+
const auto & k_idxs = inp->get_k_idxs();
|
|
2134
|
+
|
|
2135
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
|
2136
|
+
}
|
|
2137
|
+
|
|
2138
|
+
const auto & kq_mask = inp->get_kq_mask();
|
|
2139
|
+
|
|
2140
|
+
ggml_tensor * q = q_cur;
|
|
2141
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
2142
|
+
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
|
|
2143
|
+
|
|
2144
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
|
2145
|
+
cb(cur, "kqv_out", il);
|
|
2146
|
+
|
|
1761
2147
|
if (wo) {
|
|
1762
2148
|
cur = build_lora_mm(wo, cur);
|
|
1763
2149
|
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
@@ -1903,15 +2289,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|
|
1903
2289
|
|
|
1904
2290
|
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
|
1905
2291
|
|
|
1906
|
-
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1907
|
-
|
|
1908
2292
|
{
|
|
1909
|
-
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
|
1910
|
-
|
|
1911
2293
|
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
1912
2294
|
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
1913
2295
|
|
|
1914
|
-
inp->self_kq_mask =
|
|
2296
|
+
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
|
1915
2297
|
ggml_set_input(inp->self_kq_mask);
|
|
1916
2298
|
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
|
1917
2299
|
|
|
@@ -1922,12 +2304,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|
|
1922
2304
|
{
|
|
1923
2305
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
|
1924
2306
|
|
|
1925
|
-
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
|
1926
|
-
|
|
1927
2307
|
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
1928
2308
|
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
1929
2309
|
|
|
1930
|
-
inp->self_kq_mask_swa =
|
|
2310
|
+
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
|
1931
2311
|
ggml_set_input(inp->self_kq_mask_swa);
|
|
1932
2312
|
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
|
1933
2313
|
|
|
@@ -2068,10 +2448,57 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
|
2068
2448
|
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
2069
2449
|
}
|
|
2070
2450
|
|
|
2451
|
+
llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
|
|
2452
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
2453
|
+
|
|
2454
|
+
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
|
|
2455
|
+
auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
2456
|
+
|
|
2457
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
2458
|
+
|
|
2459
|
+
return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
|
|
2460
|
+
}
|
|
2461
|
+
|
|
2462
|
+
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
|
|
2463
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
|
|
2464
|
+
|
|
2465
|
+
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
|
2466
|
+
|
|
2467
|
+
// build iswa attention input
|
|
2468
|
+
const auto * attn_ctx = mctx_cur->get_attn();
|
|
2469
|
+
|
|
2470
|
+
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
|
2471
|
+
|
|
2472
|
+
{
|
|
2473
|
+
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
2474
|
+
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
2475
|
+
|
|
2476
|
+
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
|
2477
|
+
ggml_set_input(inp_attn->self_kq_mask);
|
|
2478
|
+
|
|
2479
|
+
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
|
2480
|
+
}
|
|
2481
|
+
|
|
2482
|
+
{
|
|
2483
|
+
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
2484
|
+
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
2485
|
+
|
|
2486
|
+
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
|
2487
|
+
ggml_set_input(inp_attn->self_kq_mask_swa);
|
|
2488
|
+
|
|
2489
|
+
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
|
2490
|
+
}
|
|
2491
|
+
|
|
2492
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
2493
|
+
|
|
2494
|
+
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
|
|
2495
|
+
}
|
|
2496
|
+
|
|
2071
2497
|
void llm_graph_context::build_dense_out(
|
|
2072
2498
|
ggml_tensor * dense_2,
|
|
2499
|
+
ggml_tensor * dense_2_b,
|
|
2073
2500
|
ggml_tensor * dense_3) const {
|
|
2074
|
-
if (!cparams.embeddings || !(dense_2 || dense_3)) {
|
|
2501
|
+
if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
|
|
2075
2502
|
return;
|
|
2076
2503
|
}
|
|
2077
2504
|
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
|
|
@@ -2080,6 +2507,9 @@ void llm_graph_context::build_dense_out(
|
|
|
2080
2507
|
if (dense_2) {
|
|
2081
2508
|
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
|
2082
2509
|
}
|
|
2510
|
+
if (dense_2_b) {
|
|
2511
|
+
cur = ggml_add(ctx0, cur, dense_2_b);
|
|
2512
|
+
}
|
|
2083
2513
|
if (dense_3) {
|
|
2084
2514
|
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
|
2085
2515
|
}
|
|
@@ -2093,7 +2523,8 @@ void llm_graph_context::build_pooling(
|
|
|
2093
2523
|
ggml_tensor * cls,
|
|
2094
2524
|
ggml_tensor * cls_b,
|
|
2095
2525
|
ggml_tensor * cls_out,
|
|
2096
|
-
ggml_tensor * cls_out_b
|
|
2526
|
+
ggml_tensor * cls_out_b,
|
|
2527
|
+
ggml_tensor * cls_norm) const {
|
|
2097
2528
|
if (!cparams.embeddings) {
|
|
2098
2529
|
return;
|
|
2099
2530
|
}
|
|
@@ -2132,8 +2563,15 @@ void llm_graph_context::build_pooling(
|
|
|
2132
2563
|
} break;
|
|
2133
2564
|
case LLAMA_POOLING_TYPE_RANK:
|
|
2134
2565
|
{
|
|
2135
|
-
|
|
2136
|
-
|
|
2566
|
+
if (arch == LLM_ARCH_MODERN_BERT) {
|
|
2567
|
+
// modern bert gte reranker builds mean first then applies prediction head and classifier
|
|
2568
|
+
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
|
|
2569
|
+
ggml_tensor * inp_mean = build_inp_mean();
|
|
2570
|
+
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
|
|
2571
|
+
} else {
|
|
2572
|
+
ggml_tensor * inp_cls = build_inp_cls();
|
|
2573
|
+
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
|
2574
|
+
}
|
|
2137
2575
|
|
|
2138
2576
|
// classification head
|
|
2139
2577
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
@@ -2142,7 +2580,15 @@ void llm_graph_context::build_pooling(
|
|
|
2142
2580
|
if (cls_b) {
|
|
2143
2581
|
cur = ggml_add(ctx0, cur, cls_b);
|
|
2144
2582
|
}
|
|
2145
|
-
|
|
2583
|
+
if (arch == LLM_ARCH_MODERN_BERT) {
|
|
2584
|
+
cur = ggml_gelu(ctx0, cur);
|
|
2585
|
+
} else {
|
|
2586
|
+
cur = ggml_tanh(ctx0, cur);
|
|
2587
|
+
}
|
|
2588
|
+
if (cls_norm) {
|
|
2589
|
+
// head norm
|
|
2590
|
+
cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
|
|
2591
|
+
}
|
|
2146
2592
|
}
|
|
2147
2593
|
|
|
2148
2594
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
@@ -2157,7 +2603,7 @@ void llm_graph_context::build_pooling(
|
|
|
2157
2603
|
}
|
|
2158
2604
|
|
|
2159
2605
|
// softmax for qwen3 reranker
|
|
2160
|
-
if (arch == LLM_ARCH_QWEN3) {
|
|
2606
|
+
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
|
2161
2607
|
cur = ggml_soft_max(ctx0, cur);
|
|
2162
2608
|
}
|
|
2163
2609
|
} break;
|
|
@@ -2178,6 +2624,9 @@ void llm_graph_context::build_sampling() const {
|
|
|
2178
2624
|
return;
|
|
2179
2625
|
}
|
|
2180
2626
|
|
|
2627
|
+
std::array<ggml_tensor *, 2> outs;
|
|
2628
|
+
outs[0] = res->t_logits;
|
|
2629
|
+
|
|
2181
2630
|
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
|
2182
2631
|
res->add_input(std::move(inp_sampling));
|
|
2183
2632
|
|
|
@@ -2198,14 +2647,14 @@ void llm_graph_context::build_sampling() const {
|
|
|
2198
2647
|
// add a dummy row of logits
|
|
2199
2648
|
// this trick makes the graph static, regardless of which samplers are activated
|
|
2200
2649
|
// this is important in order to minimize graph reallocations
|
|
2201
|
-
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
|
|
2202
2650
|
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
|
2203
2651
|
|
|
2204
2652
|
for (const auto & [seq_id, sampler] : samplers) {
|
|
2205
2653
|
const auto it = seq_to_logit_row.find(seq_id);
|
|
2206
2654
|
|
|
2207
2655
|
// inactive samplers always work on the first row
|
|
2208
|
-
const auto row_idx =
|
|
2656
|
+
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
|
|
2657
|
+
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
|
|
2209
2658
|
|
|
2210
2659
|
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
|
2211
2660
|
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
|
@@ -2222,22 +2671,26 @@ void llm_graph_context::build_sampling() const {
|
|
|
2222
2671
|
|
|
2223
2672
|
if (data.sampled != nullptr) {
|
|
2224
2673
|
res->t_sampled[seq_id] = data.sampled;
|
|
2225
|
-
|
|
2674
|
+
outs[1] = data.sampled;
|
|
2675
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2226
2676
|
}
|
|
2227
2677
|
|
|
2228
2678
|
if (data.probs != nullptr) {
|
|
2229
2679
|
res->t_sampled_probs[seq_id] = data.probs;
|
|
2230
|
-
|
|
2680
|
+
outs[1] = data.probs;
|
|
2681
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2231
2682
|
}
|
|
2232
2683
|
|
|
2233
2684
|
if (data.logits != nullptr) {
|
|
2234
2685
|
res->t_sampled_logits[seq_id] = data.logits;
|
|
2235
|
-
|
|
2686
|
+
outs[1] = data.logits;
|
|
2687
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2236
2688
|
}
|
|
2237
2689
|
|
|
2238
2690
|
if (data.candidates != nullptr) {
|
|
2239
2691
|
res->t_candidates[seq_id] = data.candidates;
|
|
2240
|
-
|
|
2692
|
+
outs[1] = data.candidates;
|
|
2693
|
+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
|
2241
2694
|
}
|
|
2242
2695
|
}
|
|
2243
2696
|
|