whispercpp 1.3.3 → 1.3.5
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -4,14 +4,15 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"

-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"

 #include <cassert>
 #include <cmath>
 #include <cstring>
+#include <unordered_set>

 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
@@ -28,6 +29,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 }

+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
+
+    return res;
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -50,15 +60,26 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }

+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
+
+    return res;
+}
+
 void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;

+        GGML_ASSERT(f_attn_temp_scale != 0.0f);
+        GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
         std::vector<float> attn_scale_data(n_tokens, 0.0f);
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
             attn_scale_data[i] = std::log(
-                std::floor((pos +
+                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
             ) * f_attn_temp_scale + 1.0;
         }

@@ -71,7 +92,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;

     GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

     int32_t * data = (int32_t *) pos_bucket->data;

@@ -118,6 +139,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
     }
 }

+bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= n_outputs == params.n_outputs;
+
+    return res;
+}
+
 void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -163,38 +192,26 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {

 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs_unq = ubatch->n_seqs_unq;

     if (cparams.embeddings && (
-
-
-
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
+                )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

-
-
-                const llama_seq_id seq_id = ubatch->seq_id[i][s];
-                const int32_t seq_idx = ubatch->seq_idx[seq_id];
-
-                data[seq_idx] = i;
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        GGML_ASSERT(cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
-
-        uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+        std::vector<int> target_pos(n_seqs_unq, -1);
+        std::vector<int> target_row(n_seqs_unq, -1);

-
-
+        const bool last = (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+                (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+                );

         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -203,16 +220,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
                 const llama_seq_id seq_id = ubatch->seq_id[i][s];
                 const int32_t seq_idx = ubatch->seq_idx[seq_id];

-                if (
-
-
+                if (
+                    (target_pos[seq_idx] == -1) ||
+                    ( last && pos >= target_pos[seq_idx]) ||
+                    (!last && pos < target_pos[seq_idx])
+                ) {
+                    target_pos[seq_idx] = pos;
+                    target_row[seq_idx] = i;
                 }
             }
         }

         for (int s = 0; s < n_seqs_unq; ++s) {
-            if (
-            data[s] =
+            if (target_row[s] >= 0) {
+                data[s] = target_row[s];
             }
         }
     }
@@ -234,6 +255,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }

+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
+    return res;
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);

@@ -244,56 +283,164 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }

+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+        case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+        case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
+    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+    LLAMA_LOG_DEBUG("    ");
+    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+        LLAMA_LOG_DEBUG("%2d", j);
+    }
+    LLAMA_LOG_DEBUG("\n");
+
+    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+        LLAMA_LOG_DEBUG(" %2d ", i);
+        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+            float val = data[i * n_kv + j];
+            if (val == -INFINITY) {
+                LLAMA_LOG_DEBUG(" ∞");
+            } else {
+                LLAMA_LOG_DEBUG(" 0");
+            }
+        }
+        LLAMA_LOG_DEBUG("\n");
+    }
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;

-
-
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos p1 = ubatch->pos[i1];

-
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;

-
-
-
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos p0 = ubatch->pos[i0];

-
-
+                    // mask different sequences
+                    if (s0 != s1) {
+                        continue;
+                    }

-
-
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
+                    }

-                    //
-                    if (
-
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
-                    }
-                    break;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
-                }

-
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                }
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    }

-
-
-
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }

-void
-
-
-}
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);

-
-
-
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+}
+
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    return res;
+}
+
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+}
+
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+    res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+
+    return res;
 }

 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -303,7 +450,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;

     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

     float * data = (float *) cross_kq_mask->data;

@@ -324,7 +471,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
             }
         }

-        for (int i = n_tokens; i <
+        for (int i = n_tokens; i < n_tokens; ++i) {
             for (int j = 0; j < n_enc; ++j) {
                 data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
@@ -333,15 +480,16 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-
-
-
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

     const int64_t n_rs = mctx->get_recr()->get_n_rs();

-    if (s_copy) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
-        int32_t * data = (int32_t *) s_copy->data;
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
@@ -350,10 +498,186 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }

-
-
-
-
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
+}
+
+void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
+    // set the inputs only for the active samplers in the current ubatch
+    std::unordered_set<llama_seq_id> active_samplers;
+    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+        if (ubatch->output[i]) {
+            llama_seq_id seq_id = ubatch->seq_id[i][0];
+            active_samplers.insert(seq_id);
+        }
+    }
+
+    for (auto seq_id : active_samplers) {
+        if (samplers.find(seq_id) == samplers.end()) {
+            continue;
+        }
+
+        auto & sampler = samplers[seq_id];
+
+        if (sampler->iface->backend_set_input) {
+            sampler->iface->backend_set_input(sampler);
+        }
+    }
+}
+
+bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
+    if (samplers.size() != params.samplers.size()) {
+        return false;
+    }
+
+    for (const auto & [seq_id, sampler] : params.samplers) {
+        if (samplers[seq_id] != sampler) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+//
+// llm_graph_result
+//
+
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+    reset();
+
+    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+}
+
+int64_t llm_graph_result::get_max_nodes() const {
+    return max_nodes;
+}
+
+void llm_graph_result::reset() {
+    t_tokens = nullptr;
+    t_logits = nullptr;
+    t_embd = nullptr;
+    t_embd_pooled = nullptr;
+    t_sampled.clear();
+    t_sampled_probs.clear();
+    t_sampled_logits.clear();
+    t_candidates.clear();
+
+    params = {};
+
+    inputs.clear();
+
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    ggml_init_params params = {
+        /*.mem_size =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+}
+
+void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+    for (auto & input : inputs) {
+        input->set_input(ubatch);
+    }
+}
+
+void llm_graph_result::set_outputs() {
+    if (t_logits != nullptr) {
+        ggml_set_output(t_logits);
+    }
+    if (t_embd != nullptr) {
+        ggml_set_output(t_embd);
+    }
+    if (t_embd_pooled != nullptr) {
+        ggml_set_output(t_embd_pooled);
+    }
+    for (auto & [seq_id, t] : t_sampled) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_sampled_probs) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_sampled_logits) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_candidates) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+}
+
+bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+    if (!this->params.allow_reuse(params)) {
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+        }
+
+        return false;
+    }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+    }
+
+    bool res = true;
+
+    for (auto & input : inputs) {
+        const bool cur = input->can_reuse(params);
+
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
+        }
+
+        res = res && cur;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+    }
+
+    return res;
+}
+
+llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+    inputs.emplace_back(std::move(input));
+    return inputs.back().get();
+}
+
+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
 }

 //
@@ -390,15 +714,18 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_ctx_orig (cparams.n_ctx_orig_yarn),
     pooling_type (cparams.pooling_type),
     rope_type (hparams.rope_type),
-    ctx0 (params.ctx),
     sched (params.sched),
     backend_cpu (params.backend_cpu),
     cvec (params.cvec),
     loras (params.loras),
     mctx (params.mctx),
     cross (params.cross),
+    samplers (params.samplers),
     cb_func (params.cb),
-    res (
+    res (params.res),
+    ctx0 (res->get_ctx()),
+    gf (res->get_gf()) {
+    res->set_params(params);
 }

 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -613,6 +940,8 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_reglu(ctx0, cur);
                 cb(cur, "ffn_reglu", il);
             } break;
+        default:
+            GGML_ABORT("fatal error");
     }

     if (gate && type_gate == LLM_FFN_PAR) {
@@ -622,8 +951,8 @@ ggml_tensor * llm_graph_context::build_ffn(

     if (down) {
         cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -658,13 +987,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         bool scale_w,
         float w_scale,
         llama_expert_gating_func_type gating_op,
-        int il
+        int il,
+        ggml_tensor * probs_in) const {
+    return build_moe_ffn(
+            cur,
+            gate_inp, /* gate_inp_b */ nullptr,
+            up_exps, /* up_exps_b */ nullptr,
+            gate_exps, /* gate_exps_b */ nullptr,
+            down_exps, /* down_exps_b */ nullptr,
+            exp_probs_b,
+            n_expert,
+            n_expert_used,
+            type_op,
+            norm_w,
+            scale_w,
+            w_scale,
+            gating_op,
+            il,
+            probs_in
+        );
+}
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+        ggml_tensor * cur,
+        ggml_tensor * gate_inp,
+        ggml_tensor * gate_inp_b,
+        ggml_tensor * up_exps,
+        ggml_tensor * up_exps_b,
+        ggml_tensor * gate_exps,
+        ggml_tensor * gate_exps_b,
+        ggml_tensor * down_exps,
+        ggml_tensor * down_exps_b,
+        ggml_tensor * exp_probs_b,
+        int64_t n_expert,
+        int64_t n_expert_used,
+        llm_ffn_op_type type_op,
+        bool norm_w,
+        bool scale_w,
+        float w_scale,
+        llama_expert_gating_func_type gating_op,
+        int il,
+        ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

-    ggml_tensor * logits =
-
+    ggml_tensor * logits = nullptr;
+
+    if (probs_in == nullptr) {
+        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+        cb(logits, "ffn_moe_logits", il);
+    } else {
+        logits = probs_in;
+    }
+
+    if (gate_inp_b) {
+        logits = ggml_add(ctx0, logits, gate_inp_b);
+        cb(logits, "ffn_moe_logits_biased", il);
+    }

     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -676,6 +1056,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             {
                 probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
             } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+            {
+                probs = logits; // [n_expert, n_tokens]
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -695,21 +1079,71 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }

+    if (arch == LLM_ARCH_GROVEMOE) {
+        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select top n_group_used expert groups
+    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
+    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
+        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+
+        // organize experts into n_expert_groups
+        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
+
+        ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
+
+        // get top n_group_used expert groups
+        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
+
+        ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+        cb(expert_groups, "ffn_moe_group_topk", il);
+
+        // mask out the other groups
+        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+        selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_masked", il);
+    }
+
     // select experts
-    ggml_tensor * selected_experts =
+    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);

-
-
+    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
+        // TODO: Use scalar div instead when/if implemented
+        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
+        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
+        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
+    } else {
+        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
+    }
+
+    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);

+
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+        cb(weights, "ffn_moe_weights_softmax", il);
+    }
+
     if (norm_w) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

         ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
         cb(weights_sum, "ffn_moe_weights_sum", il);

+        // Avoid division by zero, clamp to smallest number representable by F16
+        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
+        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
+
         weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
         cb(weights, "ffn_moe_weights_norm", il);

@@ -720,6 +1154,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }

+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

     if (weight_before_ffn) {
@@ -732,6 +1169,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);

+    if (up_exps_b) {
+        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+        cb(up, "ffn_moe_up_biased", il);
+    }
+
     ggml_tensor * experts = nullptr;
     if (gate_exps) {
         cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -740,6 +1182,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = up;
     }

+    if (gate_exps_b) {
+        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+        cb(cur, "ffn_moe_gate_biased", il);
+    }
+
     switch (type_op) {
         case LLM_FFN_SILU:
             if (gate_exps) {
@@ -757,6 +1204,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
+        case LLM_FFN_SWIGLU_OAI_MOE:
+            {
+                // TODO: move to hparams?
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+                cb(cur, "ffn_moe_swiglu_oai", il);
+            } break;
+        case LLM_FFN_RELU:
+            if (gate_exps) {
+                cur = ggml_reglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_reglu", il);
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_moe_relu", il);
+            } break;
+        case LLM_FFN_RELU_SQR:
+            if (gate_exps) {
+                // TODO: add support for gated squared relu
+                GGML_ABORT("fatal error: gated squared relu not implemented");
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cur = ggml_sqr(ctx0, cur);
+                cb(cur, "ffn_moe_relu_sqr", il);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -764,25 +1236,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);

+    if (down_exps_b) {
+        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
     }

+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
     // aggregate experts
-
-
-
-
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    // to avoid potentially a large number of add nodes during warmup
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];

-
-
-    } else {
-        moe_out = ggml_add(ctx0, moe_out, cur_expert);
-    }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }

-    if (n_expert_used == 1) {
+    if (hparams.n_expert_used == 1) {
         // avoid returning a non-contiguous tensor
         moe_out = ggml_cont(ctx0, moe_out);
     }
@@ -794,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(

 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.
+    const int64_t n_embd = hparams.n_embd_inp();

     auto inp = std::make_unique<llm_graph_input_embd>();

@@ -841,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {

     res->add_input(std::move(inp));

+    // make sure the produced embeddings are immediately materialized in the ggml graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18599
+    ggml_build_forward_expand(gf, cur);
+
     return cur;
 }

@@ -858,7 +1347,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }

 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

     auto & cur = inp->attn_scale;

@@ -906,7 +1395,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }

 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

     auto & cur = inp->cls;

@@ -931,7 +1420,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     // return cur;
     //}

-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
     const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -956,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

@@ -987,56 +1476,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }

-llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
-
-    {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
-
-        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
-        ggml_set_input(inp->s_copy);
-    }
-
-    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn_mha(
-        ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
-
+        float kq_scale,
+        int il) const {
     const bool v_trans = v->nb[1] > v->nb[2];

+    // split the batch into streams if needed
+    const auto n_stream = k->ne[3];
+
+    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);

-    const auto n_tokens = q->ne[1];
-    const auto n_head = q->ne[2];
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;

-
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

         if (v_trans) {
@@ -1054,8 +1517,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

-
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
+        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);

         if (v_mla) {
 #if 0
@@ -1068,14 +1533,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             // The permutations are noops and only change how the tensor data is interpreted.
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cb(cur, "fattn_mla", il);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+        cb(kq, "kq", il);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1083,42 +1550,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
-            // multiply by
+            // multiply by attn_output_multiplier
             // and then :
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq,
-            kq
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
+            cb(kq, "kq_tanh", il);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled", il);
         }
 
         if (hparams.attn_soft_cap) {
             kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_1", il);
             kq = ggml_tanh (ctx0, kq);
+            cb(kq, "kq_tanh", il);
             kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_2", il);
         }
 
         if (kq_b) {
             kq = ggml_add(ctx0, kq, kq_b);
+            cb(kq, "kq_plus_kq_b", il);
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        ggml_soft_max_add_sinks(kq, sinks);
+        cb(kq, "kq_soft_max", il);
 
         if (!v_trans) {
             // note: avoid this branch
             v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+            cb(v, "v_cont", il);
         }
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+        cb(kqv, "kqv", il);
 
         // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
         if (v_mla) {
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+            cb(kqv, "kqv_mla", il);
         }
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
     if (!cparams.offload_kqv) {
         // all nodes between the KV store and the attention output are run on the CPU
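Annotation: the soft-cap branches above bound the raw attention logits by computing f * tanh(x / f), which stays close to the identity for small values and saturates at plus/minus f (the GROK comment names f = 30). A minimal scalar sketch of that function (plain C++, illustrative only; the real computation is expressed as ggml_scale/ggml_tanh graph ops above):

// softcap_sketch.cpp (illustrative only)
#include <cmath>
#include <cstdio>

// soft-cap: f * tanh(x / f); near-identity for |x| much smaller than f, saturates at +/- f
static float softcap(float x, float f) {
    return f * std::tanh(x / f);
}

int main() {
    const float f = 30.0f; // as in the GROK branch: kq = 30 * tanh(kq / 30)
    for (float x : {1.0f, 10.0f, 100.0f, 1000.0f}) {
        std::printf("softcap(%g) = %g\n", x, softcap(x, f));
    }
}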
@@ -1135,24 +1614,33 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->
-
-
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
        ggml_tensor * v_mla,
             float     kq_scale,
               int     il) const {
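Annotation: build_attn_inp_no_cache above now allocates a second mask whenever hparams.swa_type is not LLAMA_SWA_TYPE_NONE, and per-layer attention picks between the two. As a rough illustration of what separates a plain causal mask from a sliding-window one, here is a standalone sketch (plain C++; the window width and the exact masking rule are illustrative assumptions, the real masks are filled elsewhere in llama.cpp):

// swa_mask_sketch.cpp (illustrative only)
#include <cstdio>
#include <limits>
#include <vector>

// Build an n x n additive attention mask: 0 where attention is allowed,
// -inf where it is not. window <= 0 means "no sliding window" (plain causal).
static std::vector<float> build_mask(int n, int window) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::vector<float> mask(n * n, neg_inf);
    for (int i = 0; i < n; ++i) {          // query position
        for (int j = 0; j <= i; ++j) {     // key position (causal)
            if (window <= 0 || i - j < window) {
                mask[i * n + j] = 0.0f;
            }
        }
    }
    return mask;
}

int main() {
    const int n = 6;
    for (int window : {0, 3}) {
        std::printf("window = %d\n", window);
        const auto m = build_mask(n, window);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                std::printf("%c", m[i * n + j] == 0.0f ? 'o' : '.');
            }
            std::printf("\n");
        }
    }
}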
@@ -1164,13 +1652,20 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    //       but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
+    //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1188,50 +1683,70 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-
-
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
-
+        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask =
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return
+    return inp;
+}
+
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-
-        ggml_cgraph * gf,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
        ggml_tensor * v_mla,
             float     kq_scale,
               int     il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);
 
-    const auto * mctx_cur =
+    const auto * mctx_cur = inp->mctx;
 
     // store to KV cache
     {
-
-
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }
 
     const auto & kq_mask = inp->get_kq_mask();
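Annotation: in build_attn_inp_kv_impl above, the self-attention mask is now a 4-D tensor whose last extent is the stream count: a single stream when cparams.kv_unified is set, otherwise one stream per unique sequence in the ubatch. A small sketch of that shape arithmetic (plain C++, illustrative only; names and example numbers are not from the diff):

// kv_mask_shape_sketch.cpp (illustrative only; mirrors the
// ggml_new_tensor_4d shape [n_kv, n_tokens/n_stream, 1, n_stream] above)
#include <cassert>
#include <cstdio>

struct mask_shape { long long ne[4]; };

static mask_shape kq_mask_shape(long long n_kv, long long n_tokens,
                                long long n_seqs_unq, bool kv_unified) {
    const long long n_stream = kv_unified ? 1 : n_seqs_unq;
    assert(n_tokens % n_stream == 0); // each stream gets an equal slice of the ubatch
    return { { n_kv, n_tokens / n_stream, 1, n_stream } };
}

int main() {
    const auto a = kq_mask_shape(4096, 8, 2, /*kv_unified=*/true);
    const auto b = kq_mask_shape(4096, 8, 2, /*kv_unified=*/false);
    std::printf("unified: [%lld, %lld, %lld, %lld]\n", a.ne[0], a.ne[1], a.ne[2], a.ne[3]);
    std::printf("split  : [%lld, %lld, %lld, %lld]\n", b.ne[0], b.ne[1], b.ne[2], b.ne[3]);
}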
@@ -1240,13 +1755,13 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -1259,14 +1774,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-
-        ggml_cgraph * gf,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
        ggml_tensor * v_mla,
             float     kq_scale,
               int     il) const {
@@ -1282,7 +1797,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, v_cur);
     }
 
-    const auto * mctx_iswa =
+    const auto * mctx_iswa = inp->mctx;
 
     const bool is_swa = hparams.is_swa(il);
 
@@ -1290,11 +1805,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1303,7 +1822,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1326,7 +1845,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask =
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1336,13 +1855,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
        ggml_tensor * v_mla,
             float     kq_scale,
               int     il) const {
@@ -1358,7 +1877,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1376,171 +1895,131 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-
-
-
-
-
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-            float     kq_scale,
-              int     il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
-
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
-    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
-    auto
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
+        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->
-
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
+        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }
 
-    return (
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
-        ggml_tensor *
+        ggml_tensor * state_copy_main,
+        ggml_tensor * state_copy_extra,
            int32_t    state_size,
            int32_t    n_seqs,
-           uint32_t
-           uint32_t
-           uint32_t
+           uint32_t   n_rs,
+           uint32_t   rs_head,
+           uint32_t   rs_size,
            int32_t    rs_zero,
-
+    const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size,
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
+    ggml_build_forward_expand(gf, output_states);
 
-
-
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
-
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
-llm_graph_input_rs
-
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
-
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
+    return inp;
 }
 
-
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-           int32_t    state_size,
-           int32_t    n_seqs,
-              bool    avoid_copies) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-
-        ggml_cgraph * gf,
+        llm_graph_input_rs * inp,
         ggml_tensor * s,
            int32_t    state_size,
            int32_t    n_seqs,
-
-    const auto *
+    const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = inp->mctx;
 
-    return build_rs(
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
   const llama_ubatch & ubatch,
-
+                 int   il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
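Annotation: the build_rs changes in the hunk above split the old single state_copy tensor into a "main" view (the first n_seqs indices, gathered into the rows the graph will actually update) and an "extra" view (the remaining n_rs - n_seqs indices, which are only copied back unchanged). A plain-C++ sketch of that two-part row shuffle (illustrative only; std::vector stands in for the ggml tensors and the index views):

// rs_copy_sketch.cpp (illustrative only)
#include <cstdio>
#include <vector>

using row = std::vector<float>;

// Gather the rows named by copy_main for further computation, and place the
// rows named by copy_extra right after them, untouched.
static std::vector<row> shuffle_states(const std::vector<row> & states,
                                       const std::vector<int> & copy_main,
                                       const std::vector<int> & copy_extra) {
    std::vector<row> out;
    out.reserve(copy_main.size() + copy_extra.size());
    for (int idx : copy_main)  out.push_back(states[idx]); // rows the graph will update
    for (int idx : copy_extra) out.push_back(states[idx]); // rows kept as-is
    return out;
}

int main() {
    const std::vector<row> states = { {0}, {1}, {2}, {3} };
    const auto out = shuffle_states(states, /*main=*/{2, 0}, /*extra=*/{1, 3});
    for (const auto & r : out) std::printf("%g ", r[0]); // prints: 2 0 1 3
    std::printf("\n");
}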
@@ -1550,7 +2029,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,8 +2057,39 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
+void llm_graph_context::build_dense_out(
+        ggml_tensor * dense_2,
+        ggml_tensor * dense_3) const {
+    if (!cparams.embeddings || !(dense_2 || dense_3)) {
+        return;
+    }
+    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
+    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
+
+    if (dense_2) {
+        cur = ggml_mul_mat(ctx0, dense_2, cur);
+    }
+    if (dense_3) {
+        cur = ggml_mul_mat(ctx0, dense_3, cur);
+    }
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
+
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
@@ -1623,34 +2133,32 @@ void llm_graph_context::build_pooling(
         case LLAMA_POOLING_TYPE_RANK:
             {
                 ggml_tensor * inp_cls = build_inp_cls();
-
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
 
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                 if (cls) {
-
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    cur = ggml_mul_mat(ctx0, cls, cur);
                     if (cls_b) {
                         cur = ggml_add(ctx0, cur, cls_b);
                     }
                     cur = ggml_tanh(ctx0, cur);
+                }
 
-
-
-
-
-
-
-                    }
-                }
-            } else if (cls_out) {
-                // Single layer classification head (direct projection)
-                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                cur = ggml_mul_mat(ctx0, cls_out, inp);
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                // Single layer classification head (direct projection)
+                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                if (cls_out) {
+                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                     if (cls_out_b) {
                         cur = ggml_add(ctx0, cur, cls_out_b);
                     }
-                }
-
+                }
+
+                // softmax for qwen3 reranker
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_soft_max(ctx0, cur);
                 }
             } break;
         default:
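Annotation: the reworked LLAMA_POOLING_TYPE_RANK branch above now chains an optional dense + tanh head (cls / cls_b) with an optional single-layer projection (cls_out / cls_out_b), and finishes with a softmax only for the Qwen3 reranker. A scalar sketch of that pipeline on one pooled embedding (plain C++, illustrative only; the weights and sizes are made up for the example):

// rank_head_sketch.cpp (illustrative only)
#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>; // row-major: mat[i] is one output row

static vec matvec(const mat & w, const vec & x) {
    vec y(w.size(), 0.0f);
    for (size_t i = 0; i < w.size(); ++i)
        for (size_t j = 0; j < x.size(); ++j)
            y[i] += w[i][j] * x[j];
    return y;
}

int main() {
    const vec x = {0.3f, -0.1f, 0.7f};                 // pooled (CLS) embedding
    const mat cls = {{ 0.5f,  0.2f, -0.4f},
                     { 0.1f, -0.3f,  0.6f},
                     {-0.2f,  0.4f,  0.1f}};           // optional dense head
    const vec cls_b = {0.05f, -0.02f, 0.0f};
    const mat cls_out = {{1.3f, -0.7f, 0.2f}};         // optional projection to one score
    const vec cls_out_b = {-0.2f};

    // cur = tanh(cls * x + cls_b)
    vec cur = matvec(cls, x);
    for (size_t i = 0; i < cur.size(); ++i) cur[i] = std::tanh(cur[i] + cls_b[i]);

    // score = cls_out * cur + cls_out_b
    vec score = matvec(cls_out, cur);
    score[0] += cls_out_b[0];

    std::printf("rank score: %f\n", score[0]);
}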
@@ -1665,6 +2173,87 @@ void llm_graph_context::build_pooling(
     ggml_build_forward_expand(gf, cur);
 }
 
+void llm_graph_context::build_sampling() const {
+    if (samplers.empty() || !res->t_logits) {
+        return;
+    }
+
+    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
+    res->add_input(std::move(inp_sampling));
+
+    std::map<llama_seq_id, int32_t> seq_to_logit_row;
+    int32_t logit_row_idx = 0;
+
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            llama_seq_id seq_id = ubatch.seq_id[i][0];
+            seq_to_logit_row[seq_id] = logit_row_idx;
+            logit_row_idx++;
+        }
+    }
+
+    // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
+    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
+
+    // add a dummy row of logits
+    // this trick makes the graph static, regardless of which samplers are activated
+    // this is important in order to minimize graph reallocations
+    // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
+    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
+
+    for (const auto & [seq_id, sampler] : samplers) {
+        const auto it = seq_to_logit_row.find(seq_id);
+
+        // inactive samplers always work on the first row
+        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
+
+        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
+        ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
+
+        struct llama_sampler_data data = {
+            /*.logits     =*/ logits_seq,
+            /*.probs      =*/ nullptr,
+            /*.sampled    =*/ nullptr,
+            /*.candidates =*/ nullptr,
+        };
+
+        assert(sampler->iface->backend_apply);
+        sampler->iface->backend_apply(sampler, ctx0, gf, &data);
+
+        if (data.sampled != nullptr) {
+            res->t_sampled[seq_id] = data.sampled;
+            ggml_build_forward_expand(gf, data.sampled);
+        }
+
+        if (data.probs != nullptr) {
+            res->t_sampled_probs[seq_id] = data.probs;
+            ggml_build_forward_expand(gf, data.probs);
+        }
+
+        if (data.logits != nullptr) {
+            res->t_sampled_logits[seq_id] = data.logits;
+            ggml_build_forward_expand(gf, data.logits);
+        }
+
+        if (data.candidates != nullptr) {
+            res->t_candidates[seq_id] = data.candidates;
+            ggml_build_forward_expand(gf, data.candidates);
+        }
+    }
+
+    // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
+    /*
+    for (const auto & [seq_id, sampler] : samplers) {
+        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
+            ggml_tensor * selected_token = it->second;
+            if (selected_token != nullptr) {
+                llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
+            }
+        }
+    }
+    */
+}
+
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
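Annotation: one detail of build_sampling above worth spelling out: the logits are padded with one extra row, and any sampler whose sequence has no output row in the current ubatch is simply pointed at row 0, so the graph shape stays the same no matter which samplers are active (the comments in the hunk call this the "dummy row" trick). A small sketch of that row-selection bookkeeping, without the graph parts (plain C++, illustrative only; the batch contents are made up):

// sampler_rows_sketch.cpp (illustrative only)
#include <cstdio>
#include <map>
#include <vector>

int main() {
    // which batch positions requested an output, and which sequence each belongs to
    const std::vector<int> output  = {0, 1, 0, 1};
    const std::vector<int> seq_ids = {0, 0, 1, 1};

    std::map<int, int> seq_to_logit_row;
    int logit_row_idx = 0;
    for (size_t i = 0; i < output.size(); ++i) {
        if (output[i]) {
            seq_to_logit_row[seq_ids[i]] = logit_row_idx++;
        }
    }

    // sequence 2 produced no output row in this ubatch, so its sampler falls
    // back to row 0 of the padded logits and the graph layout is unchanged
    for (int seq_id : {0, 1, 2}) {
        const auto it = seq_to_logit_row.find(seq_id);
        const int row = it != seq_to_logit_row.end() ? it->second : 0;
        std::printf("seq %d -> logits row %d\n", seq_id, row);
    }
}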
@@ -1680,7 +2269,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     if (bidirectional) {
         relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
+        relative_position = std::abs(relative_position);
     } else {
         relative_position = -std::min<int32_t>(relative_position, 0);
     }