whispercpp 1.3.2 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -3,7 +3,11 @@
|
|
3
3
|
#include "llama-impl.h"
|
4
4
|
#include "llama-batch.h"
|
5
5
|
#include "llama-cparams.h"
|
6
|
+
|
6
7
|
#include "llama-kv-cache.h"
|
8
|
+
#include "llama-kv-cache-iswa.h"
|
9
|
+
#include "llama-memory-hybrid.h"
|
10
|
+
#include "llama-memory-recurrent.h"
|
7
11
|
|
8
12
|
#include <cassert>
|
9
13
|
#include <cmath>
|
@@ -24,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
24
28
|
}
|
25
29
|
}
|
26
30
|
|
31
|
+
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
32
|
+
bool res = true;
|
33
|
+
|
34
|
+
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
35
|
+
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
36
|
+
|
37
|
+
return res;
|
38
|
+
}
|
39
|
+
|
27
40
|
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
28
41
|
if (ubatch->pos && pos) {
|
29
42
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -46,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
46
59
|
}
|
47
60
|
}
|
48
61
|
|
62
|
+
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
63
|
+
bool res = true;
|
64
|
+
|
65
|
+
res &= pos->ne[0] == params.ubatch.n_tokens;
|
66
|
+
|
67
|
+
return res;
|
68
|
+
}
|
69
|
+
|
49
70
|
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
50
71
|
if (ubatch->pos && attn_scale) {
|
51
72
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -67,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
67
88
|
const int64_t n_tokens = ubatch->n_tokens;
|
68
89
|
|
69
90
|
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
70
|
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
91
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
71
92
|
|
72
93
|
int32_t * data = (int32_t *) pos_bucket->data;
|
73
94
|
|
@@ -83,182 +104,149 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
83
104
|
|
84
105
|
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
85
106
|
if (pos_bucket) {
|
86
|
-
|
107
|
+
mctx->set_input_pos_bucket(pos_bucket, ubatch);
|
87
108
|
}
|
88
109
|
}
|
89
110
|
|
90
111
|
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
91
|
-
|
92
|
-
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
|
112
|
+
GGML_ASSERT(out_ids);
|
93
113
|
|
94
|
-
|
95
|
-
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
|
96
|
-
} else {
|
97
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
114
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
98
115
|
|
99
|
-
|
100
|
-
|
116
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
117
|
+
int32_t * data = (int32_t *) out_ids->data;
|
101
118
|
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
data[0] = n_tokens - 1;
|
118
|
-
} else {
|
119
|
-
GGML_ASSERT(n_outputs == 0);
|
120
|
-
}
|
119
|
+
if (n_outputs == n_tokens) {
|
120
|
+
for (int i = 0; i < n_tokens; ++i) {
|
121
|
+
data[i] = i;
|
122
|
+
}
|
123
|
+
|
124
|
+
return;
|
125
|
+
}
|
126
|
+
|
127
|
+
GGML_ASSERT(ubatch->output);
|
128
|
+
|
129
|
+
int n_outputs = 0;
|
130
|
+
|
131
|
+
for (int i = 0; i < n_tokens; ++i) {
|
132
|
+
if (ubatch->output[i]) {
|
133
|
+
data[n_outputs++] = i;
|
121
134
|
}
|
122
135
|
}
|
123
136
|
}
|
124
137
|
|
138
|
+
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
|
139
|
+
bool res = true;
|
140
|
+
|
141
|
+
res &= n_outputs == params.n_outputs;
|
142
|
+
|
143
|
+
return res;
|
144
|
+
}
|
145
|
+
|
125
146
|
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
126
147
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
127
148
|
const int64_t n_tokens = ubatch->n_tokens;
|
128
149
|
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
129
|
-
const int64_t
|
150
|
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
130
151
|
|
131
152
|
GGML_ASSERT(mean);
|
132
153
|
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
|
133
154
|
|
134
155
|
float * data = (float *) mean->data;
|
135
|
-
memset(mean->data, 0, n_tokens
|
136
|
-
|
137
|
-
std::vector<uint64_t> sum(n_tokens, 0);
|
138
|
-
|
139
|
-
for (int s = 0; s < n_seqs; ++s) {
|
140
|
-
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
156
|
+
memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
|
141
157
|
|
142
|
-
|
143
|
-
|
158
|
+
std::vector<uint64_t> sums(n_seqs_unq, 0);
|
159
|
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
160
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
161
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
162
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
144
163
|
|
145
|
-
|
146
|
-
}
|
147
|
-
|
148
|
-
std::vector<float> div(n_tokens, 0.0f);
|
149
|
-
for (int i = 0; i < n_tokens; ++i) {
|
150
|
-
const uint64_t s = sum[i];
|
151
|
-
if (s > 0) {
|
152
|
-
div[i] = 1.0f/float(s);
|
164
|
+
sums[seq_idx] += ubatch->n_seq_tokens;
|
153
165
|
}
|
154
166
|
}
|
155
167
|
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
168
|
+
std::vector<float> div(n_seqs_unq, 0.0f);
|
169
|
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
170
|
+
const uint64_t sum = sums[s];
|
171
|
+
if (sum > 0) {
|
172
|
+
div[s] = 1.0f/float(sum);
|
161
173
|
}
|
162
174
|
}
|
163
|
-
}
|
164
|
-
}
|
165
|
-
|
166
|
-
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
167
|
-
if (cparams.embeddings && (
|
168
|
-
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
169
|
-
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
|
170
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
171
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
172
|
-
const int64_t n_seqs = ubatch->n_seqs;
|
173
175
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
179
|
-
|
180
|
-
for (int s = 0; s < n_seqs; ++s) {
|
181
|
-
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
182
|
-
|
183
|
-
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
184
|
-
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
176
|
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
177
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
178
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
179
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
185
180
|
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
if (pos == 0) {
|
190
|
-
data[seq_id] = s*n_seq_tokens + i;
|
181
|
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
182
|
+
data[seq_idx*n_tokens + i + j] = div[seq_idx];
|
191
183
|
}
|
192
184
|
}
|
193
185
|
}
|
194
186
|
}
|
187
|
+
}
|
195
188
|
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
const int64_t n_seqs = ubatch->n_seqs;
|
189
|
+
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
190
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
191
|
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
200
192
|
|
193
|
+
if (cparams.embeddings && (
|
194
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
195
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
|
196
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
|
197
|
+
)) {
|
201
198
|
GGML_ASSERT(cls);
|
202
199
|
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
203
200
|
|
204
201
|
uint32_t * data = (uint32_t *) cls->data;
|
205
|
-
memset(cls->data, 0,
|
202
|
+
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
206
203
|
|
207
|
-
std::vector<int>
|
208
|
-
std::vector<int>
|
204
|
+
std::vector<int> target_pos(n_seqs_unq, -1);
|
205
|
+
std::vector<int> target_row(n_seqs_unq, -1);
|
209
206
|
|
210
|
-
|
211
|
-
|
207
|
+
const bool last = (
|
208
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
209
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
210
|
+
);
|
212
211
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
for (int
|
217
|
-
const
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
212
|
+
for (int i = 0; i < n_tokens; ++i) {
|
213
|
+
const llama_pos pos = ubatch->pos[i];
|
214
|
+
|
215
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
216
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
217
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
218
|
+
|
219
|
+
if (
|
220
|
+
(target_pos[seq_idx] == -1) ||
|
221
|
+
( last && pos >= target_pos[seq_idx]) ||
|
222
|
+
(!last && pos < target_pos[seq_idx])
|
223
|
+
) {
|
224
|
+
target_pos[seq_idx] = pos;
|
225
|
+
target_row[seq_idx] = i;
|
222
226
|
}
|
223
227
|
}
|
224
228
|
}
|
225
229
|
|
226
|
-
for (int
|
227
|
-
if (
|
228
|
-
data[
|
230
|
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
231
|
+
if (target_row[s] >= 0) {
|
232
|
+
data[s] = target_row[s];
|
229
233
|
}
|
230
234
|
}
|
231
235
|
}
|
232
236
|
}
|
233
237
|
|
234
|
-
void
|
238
|
+
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
235
239
|
GGML_UNUSED(ubatch);
|
236
240
|
|
237
|
-
const int64_t
|
241
|
+
const int64_t n_rs = mctx->get_n_rs();
|
238
242
|
|
239
243
|
if (s_copy) {
|
240
244
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
241
245
|
int32_t * data = (int32_t *) s_copy->data;
|
242
246
|
|
243
247
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
244
|
-
for (uint32_t i = 0; i <
|
245
|
-
data[i] =
|
246
|
-
}
|
247
|
-
}
|
248
|
-
}
|
249
|
-
|
250
|
-
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
251
|
-
GGML_UNUSED(ubatch);
|
252
|
-
|
253
|
-
const int64_t n_kv = kv_self->n;
|
254
|
-
|
255
|
-
if (s_mask) {
|
256
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
257
|
-
float * data = (float *) s_mask->data;
|
258
|
-
|
259
|
-
// clear unused states
|
260
|
-
for (int i = 0; i < n_kv; ++i) {
|
261
|
-
data[i] = kv_self->s_mask(i);
|
248
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
249
|
+
data[i] = mctx->s_copy(i);
|
262
250
|
}
|
263
251
|
}
|
264
252
|
}
|
@@ -273,142 +261,270 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
273
261
|
}
|
274
262
|
}
|
275
263
|
|
264
|
+
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
265
|
+
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
266
|
+
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
|
267
|
+
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
|
268
|
+
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
|
269
|
+
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
|
270
|
+
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
271
|
+
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
272
|
+
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
273
|
+
|
274
|
+
LLAMA_LOG_DEBUG(" ");
|
275
|
+
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
276
|
+
LLAMA_LOG_DEBUG("%2d", j);
|
277
|
+
}
|
278
|
+
LLAMA_LOG_DEBUG("\n");
|
279
|
+
|
280
|
+
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
|
281
|
+
LLAMA_LOG_DEBUG(" %2d ", i);
|
282
|
+
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
283
|
+
float val = data[i * n_kv + j];
|
284
|
+
if (val == -INFINITY) {
|
285
|
+
LLAMA_LOG_DEBUG(" ∞");
|
286
|
+
} else {
|
287
|
+
LLAMA_LOG_DEBUG(" 0");
|
288
|
+
}
|
289
|
+
}
|
290
|
+
LLAMA_LOG_DEBUG("\n");
|
291
|
+
}
|
292
|
+
}
|
293
|
+
|
276
294
|
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
301
|
-
if (hparams.use_alibi) {
|
302
|
-
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
303
|
-
} else {
|
304
|
-
f = 0.0f;
|
305
|
-
}
|
306
|
-
break;
|
307
|
-
}
|
308
|
-
}
|
309
|
-
|
310
|
-
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
311
|
-
}
|
312
|
-
}
|
295
|
+
const int64_t n_kv = ubatch->n_tokens;
|
296
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
297
|
+
|
298
|
+
GGML_ASSERT(kq_mask);
|
299
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
300
|
+
|
301
|
+
float * data = (float *) kq_mask->data;
|
302
|
+
|
303
|
+
// [TAG_NO_CACHE_ISWA]
|
304
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
|
305
|
+
|
306
|
+
for (int h = 0; h < 1; ++h) {
|
307
|
+
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
308
|
+
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
309
|
+
|
310
|
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
311
|
+
float f = -INFINITY;
|
312
|
+
|
313
|
+
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
314
|
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
315
|
+
|
316
|
+
if (s0 != s1) {
|
317
|
+
continue; // skip different sequences
|
313
318
|
}
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
331
|
-
const int32_t tj = s1*n_seq_tokens + j;
|
332
|
-
|
333
|
-
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
334
|
-
for (int i = 0; i < n_seq_tokens; ++i) {
|
335
|
-
const int32_t ti = s0*n_seq_tokens + i;
|
336
|
-
float f = -INFINITY;
|
337
|
-
|
338
|
-
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
339
|
-
if (ubatch->seq_id[s0][s] == seq_id) {
|
340
|
-
if (hparams.use_alibi) {
|
341
|
-
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
342
|
-
} else {
|
343
|
-
f = 0.0f;
|
344
|
-
}
|
345
|
-
break;
|
346
|
-
}
|
347
|
-
}
|
348
|
-
|
349
|
-
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
350
|
-
}
|
351
|
-
}
|
352
|
-
|
353
|
-
for (int i = n_tokens; i < n_stride; ++i) {
|
354
|
-
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
355
|
-
}
|
319
|
+
|
320
|
+
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
|
321
|
+
continue; // skip future tokens for causal attention
|
322
|
+
}
|
323
|
+
|
324
|
+
// TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
|
325
|
+
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
|
326
|
+
// continue; // skip masked tokens for SWA
|
327
|
+
//}
|
328
|
+
|
329
|
+
// TODO: reimplement this like in llama_kv_cache_unified
|
330
|
+
if (hparams.use_alibi) {
|
331
|
+
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
332
|
+
} else {
|
333
|
+
f = 0.0f;
|
356
334
|
}
|
357
335
|
}
|
336
|
+
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
358
337
|
}
|
359
338
|
}
|
360
339
|
}
|
340
|
+
if (debug) {
|
341
|
+
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
342
|
+
}
|
361
343
|
}
|
362
344
|
|
363
|
-
void
|
364
|
-
|
365
|
-
|
366
|
-
|
345
|
+
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
346
|
+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
347
|
+
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
348
|
+
|
349
|
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
367
350
|
}
|
368
351
|
|
369
|
-
|
370
|
-
|
371
|
-
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
372
|
-
}
|
352
|
+
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
353
|
+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
373
354
|
|
374
|
-
|
375
|
-
|
376
|
-
|
355
|
+
this->mctx = mctx;
|
356
|
+
|
357
|
+
bool res = true;
|
358
|
+
|
359
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
360
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
361
|
+
|
362
|
+
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
363
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
364
|
+
|
365
|
+
return res;
|
366
|
+
}
|
367
|
+
|
368
|
+
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
369
|
+
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
370
|
+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
371
|
+
|
372
|
+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
373
|
+
|
374
|
+
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
375
|
+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
376
|
+
|
377
|
+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
378
|
+
}
|
379
|
+
|
380
|
+
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
381
|
+
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
|
382
|
+
|
383
|
+
this->mctx = mctx;
|
384
|
+
|
385
|
+
bool res = true;
|
386
|
+
|
387
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
388
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
389
|
+
|
390
|
+
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
391
|
+
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
392
|
+
|
393
|
+
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
394
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
395
|
+
|
396
|
+
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
397
|
+
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
398
|
+
|
399
|
+
return res;
|
377
400
|
}
|
378
401
|
|
379
402
|
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
380
|
-
|
381
|
-
const int64_t n_enc = cross_kq_mask->ne[0];
|
382
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
403
|
+
GGML_ASSERT(cross_kq_mask);
|
383
404
|
|
384
|
-
|
385
|
-
|
405
|
+
const int64_t n_enc = cross_kq_mask->ne[0];
|
406
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
386
407
|
|
387
|
-
|
408
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
409
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
388
410
|
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
411
|
+
float * data = (float *) cross_kq_mask->data;
|
412
|
+
|
413
|
+
for (int h = 0; h < 1; ++h) {
|
414
|
+
for (int i = 0; i < n_tokens; ++i) {
|
415
|
+
for (int j = 0; j < n_enc; ++j) {
|
416
|
+
float f = -INFINITY;
|
417
|
+
|
418
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
419
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
420
|
+
|
421
|
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
422
|
+
f = 0.0f;
|
398
423
|
}
|
399
|
-
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
400
424
|
}
|
425
|
+
|
426
|
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
401
427
|
}
|
428
|
+
}
|
402
429
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
}
|
430
|
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
431
|
+
for (int j = 0; j < n_enc; ++j) {
|
432
|
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
407
433
|
}
|
408
434
|
}
|
409
435
|
}
|
410
436
|
}
|
411
437
|
|
438
|
+
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
439
|
+
inp_attn->set_input(ubatch);
|
440
|
+
inp_rs->set_input(ubatch);
|
441
|
+
}
|
442
|
+
|
443
|
+
//
|
444
|
+
// llm_graph_result
|
445
|
+
//
|
446
|
+
|
447
|
+
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
|
448
|
+
reset();
|
449
|
+
|
450
|
+
const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
|
451
|
+
debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
|
452
|
+
}
|
453
|
+
|
454
|
+
int64_t llm_graph_result::get_max_nodes() const {
|
455
|
+
return max_nodes;
|
456
|
+
}
|
457
|
+
|
458
|
+
void llm_graph_result::reset() {
|
459
|
+
t_tokens = nullptr;
|
460
|
+
t_logits = nullptr;
|
461
|
+
t_embd = nullptr;
|
462
|
+
t_embd_pooled = nullptr;
|
463
|
+
|
464
|
+
params = {};
|
465
|
+
|
466
|
+
inputs.clear();
|
467
|
+
|
468
|
+
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
469
|
+
|
470
|
+
ggml_init_params params = {
|
471
|
+
/*.mem_size =*/ buf_compute_meta.size(),
|
472
|
+
/*.mem_buffer =*/ buf_compute_meta.data(),
|
473
|
+
/*.no_alloc =*/ true,
|
474
|
+
};
|
475
|
+
|
476
|
+
ctx_compute.reset(ggml_init(params));
|
477
|
+
|
478
|
+
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
|
479
|
+
}
|
480
|
+
|
481
|
+
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
482
|
+
for (auto & input : inputs) {
|
483
|
+
input->set_input(ubatch);
|
484
|
+
}
|
485
|
+
}
|
486
|
+
|
487
|
+
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
488
|
+
if (!this->params.allow_reuse(params)) {
|
489
|
+
if (debug > 1) {
|
490
|
+
LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
|
491
|
+
}
|
492
|
+
|
493
|
+
return false;
|
494
|
+
}
|
495
|
+
|
496
|
+
if (debug > 1) {
|
497
|
+
LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
|
498
|
+
}
|
499
|
+
|
500
|
+
bool res = true;
|
501
|
+
|
502
|
+
for (auto & input : inputs) {
|
503
|
+
const bool cur = input->can_reuse(params);
|
504
|
+
|
505
|
+
if (debug > 1) {
|
506
|
+
LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
|
507
|
+
}
|
508
|
+
|
509
|
+
res = res && cur;
|
510
|
+
}
|
511
|
+
|
512
|
+
if (debug > 0) {
|
513
|
+
LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
|
514
|
+
}
|
515
|
+
|
516
|
+
return res;
|
517
|
+
}
|
518
|
+
|
519
|
+
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
|
520
|
+
inputs.emplace_back(std::move(input));
|
521
|
+
return inputs.back().get();
|
522
|
+
}
|
523
|
+
|
524
|
+
void llm_graph_result::set_params(const llm_graph_params & params) {
|
525
|
+
this->params = params;
|
526
|
+
}
|
527
|
+
|
412
528
|
//
|
413
529
|
// llm_graph_context
|
414
530
|
//
|
@@ -443,21 +559,19 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
443
559
|
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
444
560
|
pooling_type (cparams.pooling_type),
|
445
561
|
rope_type (hparams.rope_type),
|
446
|
-
ctx0 (params.ctx),
|
447
562
|
sched (params.sched),
|
448
563
|
backend_cpu (params.backend_cpu),
|
449
564
|
cvec (params.cvec),
|
450
565
|
loras (params.loras),
|
451
|
-
|
566
|
+
mctx (params.mctx),
|
452
567
|
cross (params.cross),
|
453
568
|
cb_func (params.cb),
|
454
|
-
res (
|
569
|
+
res (params.res),
|
570
|
+
ctx0 (res->get_ctx()),
|
571
|
+
gf (res->get_gf()) {
|
572
|
+
res->set_params(params);
|
455
573
|
}
|
456
574
|
|
457
|
-
int64_t llm_graph_context::n_pos_per_embd() const {
|
458
|
-
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
|
459
|
-
}
|
460
|
-
|
461
575
|
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
462
576
|
if (cb_func) {
|
463
577
|
cb_func(ubatch, cur, name, il);
|
@@ -617,12 +731,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
617
731
|
|
618
732
|
switch (type_op) {
|
619
733
|
case LLM_FFN_SILU:
|
620
|
-
{
|
734
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
735
|
+
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
736
|
+
cb(cur, "ffn_swiglu", il);
|
737
|
+
type_gate = LLM_FFN_SEQ;
|
738
|
+
} else {
|
621
739
|
cur = ggml_silu(ctx0, cur);
|
622
740
|
cb(cur, "ffn_silu", il);
|
623
741
|
} break;
|
624
742
|
case LLM_FFN_GELU:
|
625
|
-
{
|
743
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
744
|
+
cur = ggml_geglu_split(ctx0, cur, tmp);
|
745
|
+
cb(cur, "ffn_geglu", il);
|
746
|
+
type_gate = LLM_FFN_SEQ;
|
747
|
+
} else {
|
626
748
|
cur = ggml_gelu(ctx0, cur);
|
627
749
|
cb(cur, "ffn_gelu", il);
|
628
750
|
if (act_scales != NULL) {
|
@@ -631,7 +753,11 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
631
753
|
}
|
632
754
|
} break;
|
633
755
|
case LLM_FFN_RELU:
|
634
|
-
{
|
756
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
757
|
+
cur = ggml_reglu_split(ctx0, cur, tmp);
|
758
|
+
cb(cur, "ffn_reglu", il);
|
759
|
+
type_gate = LLM_FFN_SEQ;
|
760
|
+
} else {
|
635
761
|
cur = ggml_relu(ctx0, cur);
|
636
762
|
cb(cur, "ffn_relu", il);
|
637
763
|
} break;
|
@@ -645,17 +771,21 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
645
771
|
} break;
|
646
772
|
case LLM_FFN_SWIGLU:
|
647
773
|
{
|
648
|
-
|
649
|
-
|
650
|
-
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
651
|
-
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
652
|
-
|
653
|
-
x0 = ggml_silu(ctx0, x0);
|
654
|
-
cb(cur, "ffn_silu", il);
|
655
|
-
|
656
|
-
cur = ggml_mul(ctx0, x0, x1);
|
657
|
-
cb(cur, "ffn_mul", il);
|
774
|
+
cur = ggml_swiglu(ctx0, cur);
|
775
|
+
cb(cur, "ffn_swiglu", il);
|
658
776
|
} break;
|
777
|
+
case LLM_FFN_GEGLU:
|
778
|
+
{
|
779
|
+
cur = ggml_geglu(ctx0, cur);
|
780
|
+
cb(cur, "ffn_geglu", il);
|
781
|
+
} break;
|
782
|
+
case LLM_FFN_REGLU:
|
783
|
+
{
|
784
|
+
cur = ggml_reglu(ctx0, cur);
|
785
|
+
cb(cur, "ffn_reglu", il);
|
786
|
+
} break;
|
787
|
+
default:
|
788
|
+
GGML_ABORT("fatal error");
|
659
789
|
}
|
660
790
|
|
661
791
|
if (gate && type_gate == LLM_FFN_PAR) {
|
@@ -665,8 +795,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
665
795
|
|
666
796
|
if (down) {
|
667
797
|
cur = build_lora_mm(down, cur);
|
668
|
-
if (arch == LLM_ARCH_GLM4) {
|
669
|
-
// GLM4
|
798
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
799
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
670
800
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
671
801
|
}
|
672
802
|
}
|
@@ -701,13 +831,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
701
831
|
bool scale_w,
|
702
832
|
float w_scale,
|
703
833
|
llama_expert_gating_func_type gating_op,
|
704
|
-
int il
|
834
|
+
int il,
|
835
|
+
ggml_tensor * probs_in) const {
|
836
|
+
return build_moe_ffn(
|
837
|
+
cur,
|
838
|
+
gate_inp, /* gate_inp_b */ nullptr,
|
839
|
+
up_exps, /* up_exps_b */ nullptr,
|
840
|
+
gate_exps, /* gate_exps_b */ nullptr,
|
841
|
+
down_exps, /* down_exps_b */ nullptr,
|
842
|
+
exp_probs_b,
|
843
|
+
n_expert,
|
844
|
+
n_expert_used,
|
845
|
+
type_op,
|
846
|
+
norm_w,
|
847
|
+
scale_w,
|
848
|
+
w_scale,
|
849
|
+
gating_op,
|
850
|
+
il,
|
851
|
+
probs_in
|
852
|
+
);
|
853
|
+
}
|
854
|
+
|
855
|
+
ggml_tensor * llm_graph_context::build_moe_ffn(
|
856
|
+
ggml_tensor * cur,
|
857
|
+
ggml_tensor * gate_inp,
|
858
|
+
ggml_tensor * gate_inp_b,
|
859
|
+
ggml_tensor * up_exps,
|
860
|
+
ggml_tensor * up_exps_b,
|
861
|
+
ggml_tensor * gate_exps,
|
862
|
+
ggml_tensor * gate_exps_b,
|
863
|
+
ggml_tensor * down_exps,
|
864
|
+
ggml_tensor * down_exps_b,
|
865
|
+
ggml_tensor * exp_probs_b,
|
866
|
+
int64_t n_expert,
|
867
|
+
int64_t n_expert_used,
|
868
|
+
llm_ffn_op_type type_op,
|
869
|
+
bool norm_w,
|
870
|
+
bool scale_w,
|
871
|
+
float w_scale,
|
872
|
+
llama_expert_gating_func_type gating_op,
|
873
|
+
int il,
|
874
|
+
ggml_tensor * probs_in) const {
|
705
875
|
const int64_t n_embd = cur->ne[0];
|
706
876
|
const int64_t n_tokens = cur->ne[1];
|
707
877
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
708
878
|
|
709
|
-
ggml_tensor * logits =
|
710
|
-
|
879
|
+
ggml_tensor * logits = nullptr;
|
880
|
+
|
881
|
+
if (probs_in == nullptr) {
|
882
|
+
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
|
883
|
+
cb(logits, "ffn_moe_logits", il);
|
884
|
+
} else {
|
885
|
+
logits = probs_in;
|
886
|
+
}
|
887
|
+
|
888
|
+
if (gate_inp_b) {
|
889
|
+
logits = ggml_add(ctx0, logits, gate_inp_b);
|
890
|
+
cb(logits, "ffn_moe_logits_biased", il);
|
891
|
+
}
|
711
892
|
|
712
893
|
ggml_tensor * probs = nullptr;
|
713
894
|
switch (gating_op) {
|
@@ -719,6 +900,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
719
900
|
{
|
720
901
|
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
721
902
|
} break;
|
903
|
+
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
|
904
|
+
{
|
905
|
+
probs = logits; // [n_expert, n_tokens]
|
906
|
+
} break;
|
722
907
|
default:
|
723
908
|
GGML_ABORT("fatal error");
|
724
909
|
}
|
@@ -738,15 +923,36 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
738
923
|
selection_probs = logits;
|
739
924
|
}
|
740
925
|
|
926
|
+
if (arch == LLM_ARCH_GROVEMOE) {
|
927
|
+
selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
928
|
+
cb(selection_probs, "ffn_moe_probs_biased", il);
|
929
|
+
}
|
930
|
+
|
741
931
|
// select experts
|
742
932
|
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
743
933
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
744
934
|
cb(selected_experts, "ffn_moe_topk", il);
|
745
935
|
|
746
|
-
|
747
|
-
|
936
|
+
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
|
937
|
+
// TODO: Use scalar div instead when/if implemented
|
938
|
+
ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
|
939
|
+
selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
|
940
|
+
probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
|
941
|
+
} else {
|
942
|
+
probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
|
943
|
+
}
|
944
|
+
|
945
|
+
ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
|
748
946
|
cb(weights, "ffn_moe_weights", il);
|
749
947
|
|
948
|
+
|
949
|
+
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
|
950
|
+
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
951
|
+
weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
|
952
|
+
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
953
|
+
cb(weights, "ffn_moe_weights_softmax", il);
|
954
|
+
}
|
955
|
+
|
750
956
|
if (norm_w) {
|
751
957
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
752
958
|
|
@@ -763,12 +969,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
763
969
|
cb(weights, "ffn_moe_weights_scaled", il);
|
764
970
|
}
|
765
971
|
|
972
|
+
//call early so that topk-moe can be used
|
973
|
+
ggml_build_forward_expand(gf, weights);
|
974
|
+
|
766
975
|
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
767
976
|
|
768
977
|
if (weight_before_ffn) {
|
769
|
-
//
|
770
|
-
ggml_tensor * repeated =
|
771
|
-
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
|
978
|
+
// repeat cur to [n_embd, n_expert_used, n_tokens]
|
979
|
+
ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
|
772
980
|
cur = ggml_mul(ctx0, repeated, weights);
|
773
981
|
cb(cur, "ffn_moe_weighted", il);
|
774
982
|
}
|
@@ -776,6 +984,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
776
984
|
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
777
985
|
cb(up, "ffn_moe_up", il);
|
778
986
|
|
987
|
+
if (up_exps_b) {
|
988
|
+
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
989
|
+
cb(up, "ffn_moe_up_biased", il);
|
990
|
+
}
|
991
|
+
|
779
992
|
ggml_tensor * experts = nullptr;
|
780
993
|
if (gate_exps) {
|
781
994
|
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
@@ -784,48 +997,83 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
784
997
|
cur = up;
|
785
998
|
}
|
786
999
|
|
1000
|
+
if (gate_exps_b) {
|
1001
|
+
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
1002
|
+
cb(cur, "ffn_moe_gate_biased", il);
|
1003
|
+
}
|
1004
|
+
|
787
1005
|
switch (type_op) {
|
788
1006
|
case LLM_FFN_SILU:
|
789
|
-
{
|
1007
|
+
if (gate_exps) {
|
1008
|
+
cur = ggml_swiglu_split(ctx0, cur, up);
|
1009
|
+
cb(cur, "ffn_moe_swiglu", il);
|
1010
|
+
} else {
|
790
1011
|
cur = ggml_silu(ctx0, cur);
|
791
1012
|
cb(cur, "ffn_moe_silu", il);
|
792
1013
|
} break;
|
793
1014
|
case LLM_FFN_GELU:
|
794
|
-
{
|
1015
|
+
if (gate_exps) {
|
1016
|
+
cur = ggml_geglu_split(ctx0, cur, up);
|
1017
|
+
cb(cur, "ffn_moe_geglu", il);
|
1018
|
+
} else {
|
795
1019
|
cur = ggml_gelu(ctx0, cur);
|
796
1020
|
cb(cur, "ffn_moe_gelu", il);
|
797
1021
|
} break;
|
1022
|
+
case LLM_FFN_SWIGLU_OAI_MOE:
|
1023
|
+
{
|
1024
|
+
// TODO: move to hparams?
|
1025
|
+
constexpr float alpha = 1.702f;
|
1026
|
+
constexpr float limit = 7.0f;
|
1027
|
+
cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
|
1028
|
+
cb(cur, "ffn_moe_swiglu_oai", il);
|
1029
|
+
} break;
|
1030
|
+
case LLM_FFN_RELU:
|
1031
|
+
if (gate_exps) {
|
1032
|
+
cur = ggml_reglu_split(ctx0, cur, up);
|
1033
|
+
cb(cur, "ffn_moe_reglu", il);
|
1034
|
+
} else {
|
1035
|
+
cur = ggml_relu(ctx0, cur);
|
1036
|
+
cb(cur, "ffn_moe_relu", il);
|
1037
|
+
} break;
|
798
1038
|
default:
|
799
1039
|
GGML_ABORT("fatal error");
|
800
1040
|
}
|
801
1041
|
|
802
|
-
if (gate_exps) {
|
803
|
-
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
804
|
-
cb(cur, "ffn_moe_gate_par", il);
|
805
|
-
}
|
806
|
-
|
807
1042
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
808
1043
|
cb(experts, "ffn_moe_down", il);
|
809
1044
|
|
1045
|
+
if (down_exps_b) {
|
1046
|
+
experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
|
1047
|
+
cb(experts, "ffn_moe_down_biased", il);
|
1048
|
+
}
|
1049
|
+
|
810
1050
|
if (!weight_before_ffn) {
|
811
1051
|
experts = ggml_mul(ctx0, experts, weights);
|
812
1052
|
cb(cur, "ffn_moe_weighted", il);
|
813
1053
|
}
|
814
1054
|
|
1055
|
+
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
|
1056
|
+
|
1057
|
+
assert(n_expert_used > 0);
|
1058
|
+
|
1059
|
+
// order the views before the adds
|
1060
|
+
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
|
1061
|
+
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
|
1062
|
+
|
1063
|
+
ggml_build_forward_expand(gf, cur_experts[i]);
|
1064
|
+
}
|
1065
|
+
|
815
1066
|
// aggregate experts
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
1067
|
+
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
|
1068
|
+
// to avoid potentially a large number of add nodes during warmup
|
1069
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
|
1070
|
+
ggml_tensor * moe_out = cur_experts[0];
|
820
1071
|
|
821
|
-
|
822
|
-
|
823
|
-
} else {
|
824
|
-
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
825
|
-
}
|
1072
|
+
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
|
1073
|
+
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
|
826
1074
|
}
|
827
1075
|
|
828
|
-
if (n_expert_used == 1) {
|
1076
|
+
if (hparams.n_expert_used == 1) {
|
829
1077
|
// avoid returning a non-contiguous tensor
|
830
1078
|
moe_out = ggml_cont(ctx0, moe_out);
|
831
1079
|
}
|
@@ -888,11 +1136,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
888
1136
|
}
|
889
1137
|
|
890
1138
|
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
891
|
-
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
|
1139
|
+
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
|
892
1140
|
|
893
1141
|
auto & cur = inp->pos;
|
894
1142
|
|
895
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
|
1143
|
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
|
896
1144
|
ggml_set_input(cur);
|
897
1145
|
|
898
1146
|
res->add_input(std::move(inp));
|
@@ -915,6 +1163,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
915
1163
|
}
|
916
1164
|
|
917
1165
|
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
1166
|
+
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
1167
|
+
// but this would make the graph topology depend on the number of output tokens, which can interfere with
|
1168
|
+
// features that require constant topology such as pipeline parallelism
|
1169
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
1170
|
+
//if (n_outputs < n_tokens) {
|
1171
|
+
// return nullptr;
|
1172
|
+
//}
|
1173
|
+
|
918
1174
|
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
|
919
1175
|
|
920
1176
|
auto & cur = inp->out_ids;
|
@@ -932,7 +1188,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
|
932
1188
|
|
933
1189
|
auto & cur = inp->mean;
|
934
1190
|
|
935
|
-
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens,
|
1191
|
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
|
936
1192
|
ggml_set_input(cur);
|
937
1193
|
|
938
1194
|
res->add_input(std::move(inp));
|
@@ -941,45 +1197,11 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
|
941
1197
|
}
|
942
1198
|
|
943
1199
|
ggml_tensor * llm_graph_context::build_inp_cls() const {
|
944
|
-
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
|
1200
|
+
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
|
945
1201
|
|
946
1202
|
auto & cur = inp->cls;
|
947
1203
|
|
948
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32,
|
949
|
-
ggml_set_input(cur);
|
950
|
-
|
951
|
-
res->add_input(std::move(inp));
|
952
|
-
|
953
|
-
return cur;
|
954
|
-
}
|
955
|
-
|
956
|
-
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
957
|
-
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
958
|
-
|
959
|
-
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
960
|
-
|
961
|
-
const auto n_kv = kv_self->n;
|
962
|
-
|
963
|
-
auto & cur = inp->s_copy;
|
964
|
-
|
965
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
966
|
-
ggml_set_input(cur);
|
967
|
-
|
968
|
-
res->add_input(std::move(inp));
|
969
|
-
|
970
|
-
return cur;
|
971
|
-
}
|
972
|
-
|
973
|
-
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
974
|
-
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
975
|
-
|
976
|
-
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
977
|
-
|
978
|
-
const auto n_kv = kv_self->n;
|
979
|
-
|
980
|
-
auto & cur = inp->s_mask;
|
981
|
-
|
982
|
-
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
1204
|
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
|
983
1205
|
ggml_set_input(cur);
|
984
1206
|
|
985
1207
|
res->add_input(std::move(inp));
|
@@ -1025,11 +1247,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
1025
1247
|
}
|
1026
1248
|
|
1027
1249
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
1028
|
-
const
|
1250
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
1029
1251
|
|
1030
|
-
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
|
1252
|
+
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
1031
1253
|
|
1032
|
-
const auto n_kv =
|
1254
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
1033
1255
|
|
1034
1256
|
auto & cur = inp->pos_bucket;
|
1035
1257
|
|
@@ -1057,23 +1279,27 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
1057
1279
|
}
|
1058
1280
|
|
1059
1281
|
ggml_tensor * llm_graph_context::build_attn_mha(
|
1060
|
-
ggml_cgraph * gf,
|
1061
1282
|
ggml_tensor * q,
|
1062
1283
|
ggml_tensor * k,
|
1063
1284
|
ggml_tensor * v,
|
1064
1285
|
ggml_tensor * kq_b,
|
1065
1286
|
ggml_tensor * kq_mask,
|
1287
|
+
ggml_tensor * sinks,
|
1066
1288
|
ggml_tensor * v_mla,
|
1067
|
-
|
1289
|
+
float kq_scale,
|
1290
|
+
int il) const {
|
1068
1291
|
const bool v_trans = v->nb[1] > v->nb[2];
|
1069
1292
|
|
1293
|
+
// split the batch into streams if needed
|
1294
|
+
const auto n_stream = k->ne[3];
|
1295
|
+
|
1296
|
+
q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
|
1297
|
+
|
1070
1298
|
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
1071
1299
|
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
1072
1300
|
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
1073
1301
|
|
1074
|
-
const auto
|
1075
|
-
const auto n_head = q->ne[2];
|
1076
|
-
const auto n_kv = k->ne[1];
|
1302
|
+
const auto n_kv = k->ne[1];
|
1077
1303
|
|
1078
1304
|
ggml_tensor * cur;
|
1079
1305
|
|
@@ -1096,8 +1322,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1096
1322
|
|
1097
1323
|
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
1098
1324
|
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
1325
|
+
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
|
1099
1326
|
|
1100
|
-
|
1327
|
+
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
1328
|
+
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
1101
1329
|
|
1102
1330
|
if (v_mla) {
|
1103
1331
|
#if 0
|
@@ -1110,14 +1338,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1110
1338
|
// The permutations are noops and only change how the tensor data is interpreted.
|
1111
1339
|
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
1112
1340
|
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
1341
|
+
cb(cur, "fattn_mla", il);
|
1113
1342
|
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
1114
1343
|
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
|
1115
1344
|
#endif
|
1116
1345
|
}
|
1117
1346
|
|
1118
|
-
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*
|
1347
|
+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
1119
1348
|
} else {
|
1120
1349
|
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
1350
|
+
cb(kq, "kq", il);
|
1121
1351
|
|
1122
1352
|
// note: this op tends to require high floating point range
|
1123
1353
|
// while for some models F16 is enough, for others it is not, so we default to F32 here
|
@@ -1125,42 +1355,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1125
1355
|
|
1126
1356
|
if (arch == LLM_ARCH_GROK) {
|
1127
1357
|
// need to do the following:
|
1128
|
-
// multiply by
|
1358
|
+
// multiply by attn_output_multiplier
|
1129
1359
|
// and then :
|
1130
1360
|
// kq = 30 * tanh(kq / 30)
|
1131
1361
|
// before the softmax below
|
1132
1362
|
|
1133
|
-
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq,
|
1134
|
-
kq
|
1363
|
+
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
|
1364
|
+
cb(kq, "kq_tanh", il);
|
1365
|
+
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
|
1366
|
+
cb(kq, "kq_scaled", il);
|
1135
1367
|
}
|
1136
1368
|
|
1137
1369
|
if (hparams.attn_soft_cap) {
|
1138
1370
|
kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
|
1371
|
+
cb(kq, "kq_scaled_1", il);
|
1139
1372
|
kq = ggml_tanh (ctx0, kq);
|
1373
|
+
cb(kq, "kq_tanh", il);
|
1140
1374
|
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
|
1375
|
+
cb(kq, "kq_scaled_2", il);
|
1141
1376
|
}
|
1142
1377
|
|
1143
1378
|
if (kq_b) {
|
1144
1379
|
kq = ggml_add(ctx0, kq, kq_b);
|
1380
|
+
cb(kq, "kq_plus_kq_b", il);
|
1145
1381
|
}
|
1146
1382
|
|
1147
1383
|
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
1384
|
+
ggml_soft_max_add_sinks(kq, sinks);
|
1385
|
+
cb(kq, "kq_soft_max", il);
|
1148
1386
|
|
1149
1387
|
if (!v_trans) {
|
1150
1388
|
// note: avoid this branch
|
1151
1389
|
v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
1390
|
+
cb(v, "v_cont", il);
|
1152
1391
|
}
|
1153
1392
|
|
1154
1393
|
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
1394
|
+
cb(kqv, "kqv", il);
|
1155
1395
|
|
1156
1396
|
// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
|
1157
1397
|
if (v_mla) {
|
1158
1398
|
kqv = ggml_mul_mat(ctx0, v_mla, kqv);
|
1399
|
+
cb(kqv, "kqv_mla", il);
|
1159
1400
|
}
|
1160
1401
|
|
1161
1402
|
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
1162
1403
|
|
1163
|
-
|
1404
|
+
// recombine streams
|
1405
|
+
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
1164
1406
|
|
1165
1407
|
if (!cparams.offload_kqv) {
|
1166
1408
|
// all nodes between the KV store and the attention output are run on the CPU
|
@@ -1177,8 +1419,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
1177
1419
|
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
1178
1420
|
|
1179
1421
|
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
1180
|
-
inp->kq_mask =
|
1181
|
-
//cb(inp_kq_mask, "KQ_mask", -1);
|
1422
|
+
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
1182
1423
|
ggml_set_input(inp->kq_mask);
|
1183
1424
|
|
1184
1425
|
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
@@ -1188,13 +1429,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
1188
1429
|
|
1189
1430
|
ggml_tensor * llm_graph_context::build_attn(
|
1190
1431
|
llm_graph_input_attn_no_cache * inp,
|
1191
|
-
ggml_cgraph * gf,
|
1192
1432
|
ggml_tensor * wo,
|
1193
1433
|
ggml_tensor * wo_b,
|
1194
1434
|
ggml_tensor * q_cur,
|
1195
1435
|
ggml_tensor * k_cur,
|
1196
1436
|
ggml_tensor * v_cur,
|
1197
1437
|
ggml_tensor * kq_b,
|
1438
|
+
ggml_tensor * sinks,
|
1198
1439
|
ggml_tensor * v_mla,
|
1199
1440
|
float kq_scale,
|
1200
1441
|
int il) const {
|
@@ -1208,11 +1449,16 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1208
1449
|
|
1209
1450
|
const auto & kq_mask = inp->get_kq_mask();
|
1210
1451
|
|
1452
|
+
// [TAG_NO_CACHE_PAD]
|
1453
|
+
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
1454
|
+
// but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
|
1455
|
+
//assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
|
1456
|
+
|
1211
1457
|
ggml_tensor * q = q_cur;
|
1212
1458
|
ggml_tensor * k = k_cur;
|
1213
1459
|
ggml_tensor * v = v_cur;
|
1214
1460
|
|
1215
|
-
ggml_tensor * cur = build_attn_mha(
|
1461
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1216
1462
|
cb(cur, "kqv_out", il);
|
1217
1463
|
|
1218
1464
|
if (wo) {
|
@@ -1230,35 +1476,51 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1230
1476
|
return cur;
|
1231
1477
|
}
|
1232
1478
|
|
1233
|
-
|
1234
|
-
|
1479
|
+
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
1480
|
+
ggml_context * ctx0,
|
1481
|
+
const llama_ubatch & ubatch,
|
1482
|
+
const llama_hparams & hparams,
|
1483
|
+
const llama_cparams & cparams,
|
1484
|
+
const llama_kv_cache_context * mctx_cur) {
|
1235
1485
|
|
1236
|
-
auto inp = std::make_unique<
|
1486
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
|
1237
1487
|
|
1238
1488
|
{
|
1239
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use
|
1489
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
1490
|
+
|
1491
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
1492
|
+
const auto n_tokens = ubatch.n_tokens;
|
1493
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
1240
1494
|
|
1241
|
-
|
1495
|
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
1496
|
+
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
1242
1497
|
|
1243
|
-
inp->self_kq_mask =
|
1244
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
1498
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1245
1499
|
ggml_set_input(inp->self_kq_mask);
|
1246
1500
|
|
1247
1501
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1248
1502
|
}
|
1249
1503
|
|
1250
|
-
return
|
1504
|
+
return inp;
|
1505
|
+
}
|
1506
|
+
|
1507
|
+
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
|
1508
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
1509
|
+
|
1510
|
+
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
1511
|
+
|
1512
|
+
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
|
1251
1513
|
}
|
1252
1514
|
|
1253
1515
|
ggml_tensor * llm_graph_context::build_attn(
|
1254
|
-
|
1255
|
-
ggml_cgraph * gf,
|
1516
|
+
llm_graph_input_attn_kv * inp,
|
1256
1517
|
ggml_tensor * wo,
|
1257
1518
|
ggml_tensor * wo_b,
|
1258
1519
|
ggml_tensor * q_cur,
|
1259
1520
|
ggml_tensor * k_cur,
|
1260
1521
|
ggml_tensor * v_cur,
|
1261
1522
|
ggml_tensor * kq_b,
|
1523
|
+
ggml_tensor * sinks,
|
1262
1524
|
ggml_tensor * v_mla,
|
1263
1525
|
float kq_scale,
|
1264
1526
|
int il) const {
|
@@ -1268,27 +1530,30 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1268
1530
|
ggml_build_forward_expand(gf, k_cur);
|
1269
1531
|
ggml_build_forward_expand(gf, v_cur);
|
1270
1532
|
|
1271
|
-
const
|
1533
|
+
const auto * mctx_cur = inp->mctx;
|
1272
1534
|
|
1273
1535
|
// store to KV cache
|
1274
1536
|
{
|
1275
|
-
|
1276
|
-
|
1537
|
+
const auto & k_idxs = inp->get_k_idxs();
|
1538
|
+
const auto & v_idxs = inp->get_v_idxs();
|
1539
|
+
|
1540
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
1541
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
1277
1542
|
}
|
1278
1543
|
|
1279
1544
|
const auto & kq_mask = inp->get_kq_mask();
|
1280
1545
|
|
1281
1546
|
ggml_tensor * q = q_cur;
|
1282
|
-
ggml_tensor * k =
|
1283
|
-
ggml_tensor * v =
|
1547
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
1548
|
+
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
1284
1549
|
|
1285
|
-
ggml_tensor * cur = build_attn_mha(
|
1550
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1286
1551
|
cb(cur, "kqv_out", il);
|
1287
1552
|
|
1288
1553
|
if (wo) {
|
1289
1554
|
cur = build_lora_mm(wo, cur);
|
1290
|
-
if (arch == LLM_ARCH_GLM4) {
|
1291
|
-
// GLM4
|
1555
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
1556
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
1292
1557
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
1293
1558
|
}
|
1294
1559
|
}
|
@@ -1300,73 +1565,56 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1300
1565
|
return cur;
|
1301
1566
|
}
|
1302
1567
|
|
1303
|
-
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
1304
|
-
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
|
1305
|
-
|
1306
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
|
1307
|
-
|
1308
|
-
{
|
1309
|
-
const auto n_kv = kv_self->get_kv_base()->get_n();
|
1310
|
-
|
1311
|
-
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
1312
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
1313
|
-
ggml_set_input(inp->self_kq_mask);
|
1314
|
-
|
1315
|
-
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1316
|
-
}
|
1317
|
-
|
1318
|
-
{
|
1319
|
-
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
1320
|
-
|
1321
|
-
const auto n_kv = kv_self->get_kv_swa()->get_n();
|
1322
|
-
|
1323
|
-
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
1324
|
-
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
1325
|
-
ggml_set_input(inp->self_kq_mask_swa);
|
1326
|
-
|
1327
|
-
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
1328
|
-
}
|
1329
|
-
|
1330
|
-
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
1331
|
-
}
|
1332
|
-
|
1333
1568
|
ggml_tensor * llm_graph_context::build_attn(
|
1334
|
-
|
1335
|
-
ggml_cgraph * gf,
|
1569
|
+
llm_graph_input_attn_kv_iswa * inp,
|
1336
1570
|
ggml_tensor * wo,
|
1337
1571
|
ggml_tensor * wo_b,
|
1338
1572
|
ggml_tensor * q_cur,
|
1339
1573
|
ggml_tensor * k_cur,
|
1340
1574
|
ggml_tensor * v_cur,
|
1341
1575
|
ggml_tensor * kq_b,
|
1576
|
+
ggml_tensor * sinks,
|
1342
1577
|
ggml_tensor * v_mla,
|
1343
1578
|
float kq_scale,
|
1344
1579
|
int il) const {
|
1345
1580
|
// these nodes are added to the graph together so that they are not reordered
|
1346
1581
|
// by doing so, the number of splits in the graph is reduced
|
1347
1582
|
ggml_build_forward_expand(gf, q_cur);
|
1348
|
-
|
1349
|
-
|
1583
|
+
|
1584
|
+
if (k_cur) {
|
1585
|
+
ggml_build_forward_expand(gf, k_cur);
|
1586
|
+
}
|
1587
|
+
|
1588
|
+
if (v_cur) {
|
1589
|
+
ggml_build_forward_expand(gf, v_cur);
|
1590
|
+
}
|
1591
|
+
|
1592
|
+
const auto * mctx_iswa = inp->mctx;
|
1350
1593
|
|
1351
1594
|
const bool is_swa = hparams.is_swa(il);
|
1352
1595
|
|
1353
|
-
const
|
1596
|
+
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
1354
1597
|
|
1355
|
-
|
1598
|
+
// optionally store to KV cache
|
1599
|
+
if (k_cur) {
|
1600
|
+
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
1356
1601
|
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1602
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
1603
|
+
}
|
1604
|
+
|
1605
|
+
if (v_cur) {
|
1606
|
+
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
1607
|
+
|
1608
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
1361
1609
|
}
|
1362
1610
|
|
1363
1611
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
1364
1612
|
|
1365
1613
|
ggml_tensor * q = q_cur;
|
1366
|
-
ggml_tensor * k =
|
1367
|
-
ggml_tensor * v =
|
1614
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
1615
|
+
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
1368
1616
|
|
1369
|
-
ggml_tensor * cur = build_attn_mha(
|
1617
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1370
1618
|
cb(cur, "kqv_out", il);
|
1371
1619
|
|
1372
1620
|
if (wo) {
|
@@ -1389,7 +1637,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
1389
1637
|
|
1390
1638
|
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
1391
1639
|
|
1392
|
-
inp->cross_kq_mask =
|
1640
|
+
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
1393
1641
|
ggml_set_input(inp->cross_kq_mask);
|
1394
1642
|
|
1395
1643
|
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
@@ -1399,13 +1647,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
1399
1647
|
|
1400
1648
|
ggml_tensor * llm_graph_context::build_attn(
|
1401
1649
|
llm_graph_input_attn_cross * inp,
|
1402
|
-
ggml_cgraph * gf,
|
1403
1650
|
ggml_tensor * wo,
|
1404
1651
|
ggml_tensor * wo_b,
|
1405
1652
|
ggml_tensor * q_cur,
|
1406
1653
|
ggml_tensor * k_cur,
|
1407
1654
|
ggml_tensor * v_cur,
|
1408
1655
|
ggml_tensor * kq_b,
|
1656
|
+
ggml_tensor * sinks,
|
1409
1657
|
ggml_tensor * v_mla,
|
1410
1658
|
float kq_scale,
|
1411
1659
|
int il) const {
|
@@ -1421,7 +1669,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1421
1669
|
ggml_tensor * k = k_cur;
|
1422
1670
|
ggml_tensor * v = v_cur;
|
1423
1671
|
|
1424
|
-
ggml_tensor * cur = build_attn_mha(
|
1672
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1425
1673
|
cb(cur, "kqv_out", il);
|
1426
1674
|
|
1427
1675
|
if (wo) {
|
@@ -1439,56 +1687,135 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1439
1687
|
return cur;
|
1440
1688
|
}
|
1441
1689
|
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
int32_t n_state,
|
1448
|
-
int32_t n_seqs) const {
|
1449
|
-
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
1690
|
+
// TODO: maybe separate the inner implementation into a separate function
|
1691
|
+
// like with the non-sliding window equivalent
|
1692
|
+
// once sliding-window hybrid caches are a thing.
|
1693
|
+
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
|
1694
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
|
1450
1695
|
|
1451
|
-
|
1452
|
-
const auto kv_head = kv_self->head;
|
1696
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
1453
1697
|
|
1454
|
-
|
1698
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
1455
1699
|
|
1456
|
-
|
1457
|
-
|
1458
|
-
// this shrinks the tensors's ne[1] to n_kv
|
1459
|
-
states = ggml_get_rows(ctx0, states, state_copy);
|
1700
|
+
{
|
1701
|
+
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
1460
1702
|
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1703
|
+
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
1704
|
+
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
1705
|
+
|
1706
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1707
|
+
ggml_set_input(inp->self_kq_mask);
|
1464
1708
|
|
1465
|
-
|
1709
|
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1710
|
+
}
|
1711
|
+
|
1712
|
+
{
|
1713
|
+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
1714
|
+
|
1715
|
+
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
1716
|
+
|
1717
|
+
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
1718
|
+
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
1719
|
+
|
1720
|
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1721
|
+
ggml_set_input(inp->self_kq_mask_swa);
|
1722
|
+
|
1723
|
+
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
1724
|
+
}
|
1725
|
+
|
1726
|
+
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
1727
|
+
}
|
1728
|
+
|
1729
|
+
ggml_tensor * llm_graph_context::build_rs(
|
1730
|
+
ggml_tensor * s,
|
1731
|
+
ggml_tensor * state_copy_main,
|
1732
|
+
ggml_tensor * state_copy_extra,
|
1733
|
+
int32_t state_size,
|
1734
|
+
int32_t n_seqs,
|
1735
|
+
uint32_t n_rs,
|
1736
|
+
uint32_t rs_head,
|
1737
|
+
uint32_t rs_size,
|
1738
|
+
int32_t rs_zero,
|
1739
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
1740
|
+
|
1741
|
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
|
1742
|
+
|
1743
|
+
// Clear a single state which will then be copied to the other cleared states.
|
1744
|
+
// Note that this is a no-op when the view is zero-sized.
|
1745
|
+
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
1746
|
+
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
1747
|
+
|
1748
|
+
// copy states
|
1749
|
+
// NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
|
1750
|
+
// {state_size, rs_size} -> {state_size, n_seqs}
|
1751
|
+
ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
|
1752
|
+
ggml_build_forward_expand(gf, output_states);
|
1753
|
+
|
1754
|
+
// copy extra states which won't be changed further (between n_seqs and n_rs)
|
1755
|
+
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
|
1466
1756
|
ggml_build_forward_expand(gf,
|
1467
1757
|
ggml_cpy(ctx0,
|
1468
|
-
|
1469
|
-
ggml_view_1d(ctx0, s,
|
1758
|
+
states_extra,
|
1759
|
+
ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
|
1760
|
+
|
1761
|
+
return output_states;
|
1762
|
+
}
|
1763
|
+
|
1764
|
+
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
1765
|
+
ggml_context * ctx0,
|
1766
|
+
const llama_ubatch & ubatch,
|
1767
|
+
const llama_memory_recurrent_context * mctx_cur) {
|
1768
|
+
|
1769
|
+
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
1770
|
+
|
1771
|
+
const int64_t n_rs = mctx_cur->get_n_rs();
|
1772
|
+
const int64_t n_seqs = ubatch.n_seqs;
|
1773
|
+
|
1774
|
+
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
1775
|
+
ggml_set_input(inp->s_copy);
|
1776
|
+
|
1777
|
+
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
1778
|
+
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
1470
1779
|
|
1471
|
-
|
1472
|
-
|
1780
|
+
return inp;
|
1781
|
+
}
|
1782
|
+
|
1783
|
+
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
1784
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
1785
|
+
|
1786
|
+
auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
|
1787
|
+
|
1788
|
+
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
1789
|
+
}
|
1790
|
+
|
1791
|
+
ggml_tensor * llm_graph_context::build_rs(
|
1792
|
+
llm_graph_input_rs * inp,
|
1793
|
+
ggml_tensor * s,
|
1794
|
+
int32_t state_size,
|
1795
|
+
int32_t n_seqs,
|
1796
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
1797
|
+
const auto * kv_state = inp->mctx;
|
1798
|
+
|
1799
|
+
return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
|
1800
|
+
kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
|
1801
|
+
get_state_rows);
|
1473
1802
|
}
|
1474
1803
|
|
1475
1804
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
int il) const {
|
1481
|
-
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
1805
|
+
llm_graph_input_rs * inp,
|
1806
|
+
const llama_ubatch & ubatch,
|
1807
|
+
int il) const {
|
1808
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
1482
1809
|
|
1483
1810
|
const auto token_shift_count = hparams.token_shift_count;
|
1484
1811
|
|
1485
1812
|
const int64_t n_seqs = ubatch.n_seqs;
|
1486
1813
|
|
1487
|
-
ggml_tensor * token_shift_all =
|
1814
|
+
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
1488
1815
|
|
1489
|
-
ggml_tensor * token_shift =
|
1490
|
-
|
1491
|
-
hparams.
|
1816
|
+
ggml_tensor * token_shift = build_rs(
|
1817
|
+
inp, token_shift_all,
|
1818
|
+
hparams.n_embd_r(), n_seqs);
|
1492
1819
|
|
1493
1820
|
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
1494
1821
|
|
@@ -1499,24 +1826,34 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
1499
1826
|
ggml_tensor * token_shift,
|
1500
1827
|
const llama_ubatch & ubatch,
|
1501
1828
|
int il) const {
|
1502
|
-
const
|
1829
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
1503
1830
|
|
1504
1831
|
const auto token_shift_count = hparams.token_shift_count;
|
1505
1832
|
const auto n_embd = hparams.n_embd;
|
1506
1833
|
|
1507
1834
|
const int64_t n_seqs = ubatch.n_seqs;
|
1508
1835
|
|
1509
|
-
const auto kv_head =
|
1836
|
+
const auto kv_head = mctx_cur->get_head();
|
1510
1837
|
|
1511
1838
|
return ggml_cpy(
|
1512
1839
|
ctx0,
|
1513
1840
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
1514
|
-
ggml_view_1d(ctx0,
|
1841
|
+
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
|
1515
1842
|
);
|
1516
1843
|
}
|
1517
1844
|
|
1845
|
+
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
1846
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
1847
|
+
|
1848
|
+
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
1849
|
+
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
1850
|
+
|
1851
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
1852
|
+
|
1853
|
+
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
1854
|
+
}
|
1855
|
+
|
1518
1856
|
void llm_graph_context::build_pooling(
|
1519
|
-
ggml_cgraph * gf,
|
1520
1857
|
ggml_tensor * cls,
|
1521
1858
|
ggml_tensor * cls_b,
|
1522
1859
|
ggml_tensor * cls_out,
|
@@ -1560,22 +1897,32 @@ void llm_graph_context::build_pooling(
|
|
1560
1897
|
case LLAMA_POOLING_TYPE_RANK:
|
1561
1898
|
{
|
1562
1899
|
ggml_tensor * inp_cls = build_inp_cls();
|
1563
|
-
|
1900
|
+
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
1564
1901
|
|
1565
1902
|
// classification head
|
1566
1903
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1904
|
+
if (cls) {
|
1905
|
+
cur = ggml_mul_mat(ctx0, cls, cur);
|
1906
|
+
if (cls_b) {
|
1907
|
+
cur = ggml_add(ctx0, cur, cls_b);
|
1908
|
+
}
|
1909
|
+
cur = ggml_tanh(ctx0, cur);
|
1910
|
+
}
|
1572
1911
|
|
1573
1912
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
1574
1913
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
1914
|
+
// Single layer classification head (direct projection)
|
1915
|
+
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
1575
1916
|
if (cls_out) {
|
1576
|
-
|
1917
|
+
cur = ggml_mul_mat(ctx0, cls_out, cur);
|
1918
|
+
if (cls_out_b) {
|
1919
|
+
cur = ggml_add(ctx0, cur, cls_out_b);
|
1920
|
+
}
|
1921
|
+
}
|
1577
1922
|
|
1578
|
-
|
1923
|
+
// softmax for qwen3 reranker
|
1924
|
+
if (arch == LLM_ARCH_QWEN3) {
|
1925
|
+
cur = ggml_soft_max(ctx0, cur);
|
1579
1926
|
}
|
1580
1927
|
} break;
|
1581
1928
|
default:
|