whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -4,8 +4,8 @@
|
|
4
4
|
#include "llama-batch.h"
|
5
5
|
#include "llama-cparams.h"
|
6
6
|
|
7
|
-
#include "llama-kv-cache
|
8
|
-
#include "llama-kv-cache-
|
7
|
+
#include "llama-kv-cache.h"
|
8
|
+
#include "llama-kv-cache-iswa.h"
|
9
9
|
#include "llama-memory-hybrid.h"
|
10
10
|
#include "llama-memory-recurrent.h"
|
11
11
|
|
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
28
28
|
}
|
29
29
|
}
|
30
30
|
|
31
|
+
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
32
|
+
bool res = true;
|
33
|
+
|
34
|
+
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
35
|
+
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
36
|
+
|
37
|
+
return res;
|
38
|
+
}
|
39
|
+
|
31
40
|
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
32
41
|
if (ubatch->pos && pos) {
|
33
42
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
50
59
|
}
|
51
60
|
}
|
52
61
|
|
62
|
+
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
63
|
+
bool res = true;
|
64
|
+
|
65
|
+
res &= pos->ne[0] == params.ubatch.n_tokens;
|
66
|
+
|
67
|
+
return res;
|
68
|
+
}
|
69
|
+
|
53
70
|
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
54
71
|
if (ubatch->pos && attn_scale) {
|
55
72
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
71
88
|
const int64_t n_tokens = ubatch->n_tokens;
|
72
89
|
|
73
90
|
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
74
|
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
91
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
75
92
|
|
76
93
|
int32_t * data = (int32_t *) pos_bucket->data;
|
77
94
|
|
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
|
118
135
|
}
|
119
136
|
}
|
120
137
|
|
138
|
+
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
|
139
|
+
bool res = true;
|
140
|
+
|
141
|
+
res &= n_outputs == params.n_outputs;
|
142
|
+
|
143
|
+
return res;
|
144
|
+
}
|
145
|
+
|
121
146
|
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
122
147
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
123
148
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -163,38 +188,26 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
163
188
|
|
164
189
|
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
165
190
|
const int64_t n_tokens = ubatch->n_tokens;
|
166
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
167
191
|
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
168
192
|
|
169
193
|
if (cparams.embeddings && (
|
170
|
-
|
171
|
-
|
172
|
-
|
194
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
195
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
|
196
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
|
197
|
+
)) {
|
173
198
|
GGML_ASSERT(cls);
|
174
199
|
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
175
200
|
|
176
201
|
uint32_t * data = (uint32_t *) cls->data;
|
177
202
|
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
178
203
|
|
179
|
-
|
180
|
-
|
181
|
-
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
182
|
-
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
183
|
-
|
184
|
-
data[seq_idx] = i;
|
185
|
-
}
|
186
|
-
}
|
187
|
-
}
|
204
|
+
std::vector<int> target_pos(n_seqs_unq, -1);
|
205
|
+
std::vector<int> target_row(n_seqs_unq, -1);
|
188
206
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
uint32_t * data = (uint32_t *) cls->data;
|
194
|
-
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
195
|
-
|
196
|
-
std::vector<int> last_pos(n_seqs_unq, -1);
|
197
|
-
std::vector<int> last_row(n_seqs_unq, -1);
|
207
|
+
const bool last = (
|
208
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
209
|
+
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
210
|
+
);
|
198
211
|
|
199
212
|
for (int i = 0; i < n_tokens; ++i) {
|
200
213
|
const llama_pos pos = ubatch->pos[i];
|
@@ -203,16 +216,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
203
216
|
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
204
217
|
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
205
218
|
|
206
|
-
if (
|
207
|
-
|
208
|
-
|
219
|
+
if (
|
220
|
+
(target_pos[seq_idx] == -1) ||
|
221
|
+
( last && pos >= target_pos[seq_idx]) ||
|
222
|
+
(!last && pos < target_pos[seq_idx])
|
223
|
+
) {
|
224
|
+
target_pos[seq_idx] = pos;
|
225
|
+
target_row[seq_idx] = i;
|
209
226
|
}
|
210
227
|
}
|
211
228
|
}
|
212
229
|
|
213
230
|
for (int s = 0; s < n_seqs_unq; ++s) {
|
214
|
-
if (
|
215
|
-
data[s] =
|
231
|
+
if (target_row[s] >= 0) {
|
232
|
+
data[s] = target_row[s];
|
216
233
|
}
|
217
234
|
}
|
218
235
|
}
|
@@ -244,6 +261,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
244
261
|
}
|
245
262
|
}
|
246
263
|
|
264
|
+
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
265
|
+
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
266
|
+
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
|
267
|
+
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
|
268
|
+
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
|
269
|
+
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
|
270
|
+
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
271
|
+
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
272
|
+
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
273
|
+
|
274
|
+
LLAMA_LOG_DEBUG(" ");
|
275
|
+
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
276
|
+
LLAMA_LOG_DEBUG("%2d", j);
|
277
|
+
}
|
278
|
+
LLAMA_LOG_DEBUG("\n");
|
279
|
+
|
280
|
+
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
|
281
|
+
LLAMA_LOG_DEBUG(" %2d ", i);
|
282
|
+
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
283
|
+
float val = data[i * n_kv + j];
|
284
|
+
if (val == -INFINITY) {
|
285
|
+
LLAMA_LOG_DEBUG(" ∞");
|
286
|
+
} else {
|
287
|
+
LLAMA_LOG_DEBUG(" 0");
|
288
|
+
}
|
289
|
+
}
|
290
|
+
LLAMA_LOG_DEBUG("\n");
|
291
|
+
}
|
292
|
+
}
|
293
|
+
|
247
294
|
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
248
295
|
const int64_t n_kv = ubatch->n_tokens;
|
249
296
|
const int64_t n_tokens = ubatch->n_tokens;
|
@@ -253,6 +300,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
253
300
|
|
254
301
|
float * data = (float *) kq_mask->data;
|
255
302
|
|
303
|
+
// [TAG_NO_CACHE_ISWA]
|
304
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
|
305
|
+
|
256
306
|
for (int h = 0; h < 1; ++h) {
|
257
307
|
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
258
308
|
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
@@ -263,37 +313,90 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
263
313
|
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
264
314
|
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
265
315
|
|
316
|
+
if (s0 != s1) {
|
317
|
+
continue; // skip different sequences
|
318
|
+
}
|
319
|
+
|
320
|
+
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
|
321
|
+
continue; // skip future tokens for causal attention
|
322
|
+
}
|
323
|
+
|
324
|
+
// TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
|
325
|
+
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
|
326
|
+
// continue; // skip masked tokens for SWA
|
327
|
+
//}
|
328
|
+
|
266
329
|
// TODO: reimplement this like in llama_kv_cache_unified
|
267
|
-
if (
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
f = 0.0f;
|
272
|
-
}
|
273
|
-
break;
|
330
|
+
if (hparams.use_alibi) {
|
331
|
+
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
332
|
+
} else {
|
333
|
+
f = 0.0f;
|
274
334
|
}
|
275
335
|
}
|
276
|
-
|
277
336
|
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
278
337
|
}
|
279
338
|
}
|
280
339
|
}
|
340
|
+
if (debug) {
|
341
|
+
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
342
|
+
}
|
281
343
|
}
|
282
344
|
|
283
|
-
void
|
284
|
-
|
285
|
-
|
286
|
-
|
345
|
+
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
346
|
+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
347
|
+
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
348
|
+
|
349
|
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
287
350
|
}
|
288
351
|
|
289
|
-
|
290
|
-
|
291
|
-
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
292
|
-
}
|
352
|
+
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
353
|
+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
293
354
|
|
294
|
-
|
295
|
-
|
296
|
-
|
355
|
+
this->mctx = mctx;
|
356
|
+
|
357
|
+
bool res = true;
|
358
|
+
|
359
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
360
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
361
|
+
|
362
|
+
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
363
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
364
|
+
|
365
|
+
return res;
|
366
|
+
}
|
367
|
+
|
368
|
+
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
369
|
+
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
370
|
+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
371
|
+
|
372
|
+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
373
|
+
|
374
|
+
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
375
|
+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
376
|
+
|
377
|
+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
378
|
+
}
|
379
|
+
|
380
|
+
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
381
|
+
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
|
382
|
+
|
383
|
+
this->mctx = mctx;
|
384
|
+
|
385
|
+
bool res = true;
|
386
|
+
|
387
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
388
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
389
|
+
|
390
|
+
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
391
|
+
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
392
|
+
|
393
|
+
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
394
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
395
|
+
|
396
|
+
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
397
|
+
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
398
|
+
|
399
|
+
return res;
|
297
400
|
}
|
298
401
|
|
299
402
|
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
@@ -303,7 +406,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
303
406
|
const int64_t n_tokens = ubatch->n_tokens;
|
304
407
|
|
305
408
|
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
306
|
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
409
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
307
410
|
|
308
411
|
float * data = (float *) cross_kq_mask->data;
|
309
412
|
|
@@ -333,27 +436,93 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
333
436
|
}
|
334
437
|
|
335
438
|
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
336
|
-
|
337
|
-
|
439
|
+
inp_attn->set_input(ubatch);
|
440
|
+
inp_rs->set_input(ubatch);
|
441
|
+
}
|
442
|
+
|
443
|
+
//
|
444
|
+
// llm_graph_result
|
445
|
+
//
|
446
|
+
|
447
|
+
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
|
448
|
+
reset();
|
449
|
+
|
450
|
+
const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
|
451
|
+
debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
|
452
|
+
}
|
453
|
+
|
454
|
+
int64_t llm_graph_result::get_max_nodes() const {
|
455
|
+
return max_nodes;
|
456
|
+
}
|
457
|
+
|
458
|
+
void llm_graph_result::reset() {
|
459
|
+
t_tokens = nullptr;
|
460
|
+
t_logits = nullptr;
|
461
|
+
t_embd = nullptr;
|
462
|
+
t_embd_pooled = nullptr;
|
463
|
+
|
464
|
+
params = {};
|
465
|
+
|
466
|
+
inputs.clear();
|
467
|
+
|
468
|
+
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
469
|
+
|
470
|
+
ggml_init_params params = {
|
471
|
+
/*.mem_size =*/ buf_compute_meta.size(),
|
472
|
+
/*.mem_buffer =*/ buf_compute_meta.data(),
|
473
|
+
/*.no_alloc =*/ true,
|
474
|
+
};
|
475
|
+
|
476
|
+
ctx_compute.reset(ggml_init(params));
|
477
|
+
|
478
|
+
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
|
479
|
+
}
|
480
|
+
|
481
|
+
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
482
|
+
for (auto & input : inputs) {
|
483
|
+
input->set_input(ubatch);
|
338
484
|
}
|
485
|
+
}
|
339
486
|
|
340
|
-
|
487
|
+
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
488
|
+
if (!this->params.allow_reuse(params)) {
|
489
|
+
if (debug > 1) {
|
490
|
+
LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
|
491
|
+
}
|
341
492
|
|
342
|
-
|
343
|
-
|
344
|
-
int32_t * data = (int32_t *) s_copy->data;
|
493
|
+
return false;
|
494
|
+
}
|
345
495
|
|
346
|
-
|
347
|
-
|
348
|
-
|
496
|
+
if (debug > 1) {
|
497
|
+
LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
|
498
|
+
}
|
499
|
+
|
500
|
+
bool res = true;
|
501
|
+
|
502
|
+
for (auto & input : inputs) {
|
503
|
+
const bool cur = input->can_reuse(params);
|
504
|
+
|
505
|
+
if (debug > 1) {
|
506
|
+
LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
|
349
507
|
}
|
508
|
+
|
509
|
+
res = res && cur;
|
510
|
+
}
|
511
|
+
|
512
|
+
if (debug > 0) {
|
513
|
+
LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
|
350
514
|
}
|
515
|
+
|
516
|
+
return res;
|
517
|
+
}
|
518
|
+
|
519
|
+
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
|
520
|
+
inputs.emplace_back(std::move(input));
|
521
|
+
return inputs.back().get();
|
351
522
|
}
|
352
523
|
|
353
|
-
void
|
354
|
-
|
355
|
-
float f_one = 1.0f;
|
356
|
-
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
|
524
|
+
void llm_graph_result::set_params(const llm_graph_params & params) {
|
525
|
+
this->params = params;
|
357
526
|
}
|
358
527
|
|
359
528
|
//
|
@@ -390,7 +559,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
390
559
|
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
391
560
|
pooling_type (cparams.pooling_type),
|
392
561
|
rope_type (hparams.rope_type),
|
393
|
-
ctx0 (params.ctx),
|
394
562
|
sched (params.sched),
|
395
563
|
backend_cpu (params.backend_cpu),
|
396
564
|
cvec (params.cvec),
|
@@ -398,7 +566,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
398
566
|
mctx (params.mctx),
|
399
567
|
cross (params.cross),
|
400
568
|
cb_func (params.cb),
|
401
|
-
res (
|
569
|
+
res (params.res),
|
570
|
+
ctx0 (res->get_ctx()),
|
571
|
+
gf (res->get_gf()) {
|
572
|
+
res->set_params(params);
|
402
573
|
}
|
403
574
|
|
404
575
|
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
@@ -613,6 +784,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
613
784
|
cur = ggml_reglu(ctx0, cur);
|
614
785
|
cb(cur, "ffn_reglu", il);
|
615
786
|
} break;
|
787
|
+
default:
|
788
|
+
GGML_ABORT("fatal error");
|
616
789
|
}
|
617
790
|
|
618
791
|
if (gate && type_gate == LLM_FFN_PAR) {
|
@@ -622,8 +795,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
622
795
|
|
623
796
|
if (down) {
|
624
797
|
cur = build_lora_mm(down, cur);
|
625
|
-
if (arch == LLM_ARCH_GLM4) {
|
626
|
-
// GLM4
|
798
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
799
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
627
800
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
628
801
|
}
|
629
802
|
}
|
@@ -658,13 +831,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
658
831
|
bool scale_w,
|
659
832
|
float w_scale,
|
660
833
|
llama_expert_gating_func_type gating_op,
|
661
|
-
int il
|
834
|
+
int il,
|
835
|
+
ggml_tensor * probs_in) const {
|
836
|
+
return build_moe_ffn(
|
837
|
+
cur,
|
838
|
+
gate_inp, /* gate_inp_b */ nullptr,
|
839
|
+
up_exps, /* up_exps_b */ nullptr,
|
840
|
+
gate_exps, /* gate_exps_b */ nullptr,
|
841
|
+
down_exps, /* down_exps_b */ nullptr,
|
842
|
+
exp_probs_b,
|
843
|
+
n_expert,
|
844
|
+
n_expert_used,
|
845
|
+
type_op,
|
846
|
+
norm_w,
|
847
|
+
scale_w,
|
848
|
+
w_scale,
|
849
|
+
gating_op,
|
850
|
+
il,
|
851
|
+
probs_in
|
852
|
+
);
|
853
|
+
}
|
854
|
+
|
855
|
+
ggml_tensor * llm_graph_context::build_moe_ffn(
|
856
|
+
ggml_tensor * cur,
|
857
|
+
ggml_tensor * gate_inp,
|
858
|
+
ggml_tensor * gate_inp_b,
|
859
|
+
ggml_tensor * up_exps,
|
860
|
+
ggml_tensor * up_exps_b,
|
861
|
+
ggml_tensor * gate_exps,
|
862
|
+
ggml_tensor * gate_exps_b,
|
863
|
+
ggml_tensor * down_exps,
|
864
|
+
ggml_tensor * down_exps_b,
|
865
|
+
ggml_tensor * exp_probs_b,
|
866
|
+
int64_t n_expert,
|
867
|
+
int64_t n_expert_used,
|
868
|
+
llm_ffn_op_type type_op,
|
869
|
+
bool norm_w,
|
870
|
+
bool scale_w,
|
871
|
+
float w_scale,
|
872
|
+
llama_expert_gating_func_type gating_op,
|
873
|
+
int il,
|
874
|
+
ggml_tensor * probs_in) const {
|
662
875
|
const int64_t n_embd = cur->ne[0];
|
663
876
|
const int64_t n_tokens = cur->ne[1];
|
664
877
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
665
878
|
|
666
|
-
ggml_tensor * logits =
|
667
|
-
|
879
|
+
ggml_tensor * logits = nullptr;
|
880
|
+
|
881
|
+
if (probs_in == nullptr) {
|
882
|
+
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
|
883
|
+
cb(logits, "ffn_moe_logits", il);
|
884
|
+
} else {
|
885
|
+
logits = probs_in;
|
886
|
+
}
|
887
|
+
|
888
|
+
if (gate_inp_b) {
|
889
|
+
logits = ggml_add(ctx0, logits, gate_inp_b);
|
890
|
+
cb(logits, "ffn_moe_logits_biased", il);
|
891
|
+
}
|
668
892
|
|
669
893
|
ggml_tensor * probs = nullptr;
|
670
894
|
switch (gating_op) {
|
@@ -676,6 +900,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
676
900
|
{
|
677
901
|
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
678
902
|
} break;
|
903
|
+
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
|
904
|
+
{
|
905
|
+
probs = logits; // [n_expert, n_tokens]
|
906
|
+
} break;
|
679
907
|
default:
|
680
908
|
GGML_ABORT("fatal error");
|
681
909
|
}
|
@@ -695,15 +923,36 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
695
923
|
selection_probs = logits;
|
696
924
|
}
|
697
925
|
|
926
|
+
if (arch == LLM_ARCH_GROVEMOE) {
|
927
|
+
selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
928
|
+
cb(selection_probs, "ffn_moe_probs_biased", il);
|
929
|
+
}
|
930
|
+
|
698
931
|
// select experts
|
699
932
|
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
700
933
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
701
934
|
cb(selected_experts, "ffn_moe_topk", il);
|
702
935
|
|
703
|
-
|
704
|
-
|
936
|
+
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
|
937
|
+
// TODO: Use scalar div instead when/if implemented
|
938
|
+
ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
|
939
|
+
selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
|
940
|
+
probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
|
941
|
+
} else {
|
942
|
+
probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
|
943
|
+
}
|
944
|
+
|
945
|
+
ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
|
705
946
|
cb(weights, "ffn_moe_weights", il);
|
706
947
|
|
948
|
+
|
949
|
+
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
|
950
|
+
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
951
|
+
weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
|
952
|
+
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
953
|
+
cb(weights, "ffn_moe_weights_softmax", il);
|
954
|
+
}
|
955
|
+
|
707
956
|
if (norm_w) {
|
708
957
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
709
958
|
|
@@ -720,6 +969,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
720
969
|
cb(weights, "ffn_moe_weights_scaled", il);
|
721
970
|
}
|
722
971
|
|
972
|
+
//call early so that topk-moe can be used
|
973
|
+
ggml_build_forward_expand(gf, weights);
|
974
|
+
|
723
975
|
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
724
976
|
|
725
977
|
if (weight_before_ffn) {
|
@@ -732,6 +984,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
732
984
|
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
733
985
|
cb(up, "ffn_moe_up", il);
|
734
986
|
|
987
|
+
if (up_exps_b) {
|
988
|
+
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
989
|
+
cb(up, "ffn_moe_up_biased", il);
|
990
|
+
}
|
991
|
+
|
735
992
|
ggml_tensor * experts = nullptr;
|
736
993
|
if (gate_exps) {
|
737
994
|
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
@@ -740,6 +997,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
740
997
|
cur = up;
|
741
998
|
}
|
742
999
|
|
1000
|
+
if (gate_exps_b) {
|
1001
|
+
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
1002
|
+
cb(cur, "ffn_moe_gate_biased", il);
|
1003
|
+
}
|
1004
|
+
|
743
1005
|
switch (type_op) {
|
744
1006
|
case LLM_FFN_SILU:
|
745
1007
|
if (gate_exps) {
|
@@ -757,6 +1019,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
757
1019
|
cur = ggml_gelu(ctx0, cur);
|
758
1020
|
cb(cur, "ffn_moe_gelu", il);
|
759
1021
|
} break;
|
1022
|
+
case LLM_FFN_SWIGLU_OAI_MOE:
|
1023
|
+
{
|
1024
|
+
// TODO: move to hparams?
|
1025
|
+
constexpr float alpha = 1.702f;
|
1026
|
+
constexpr float limit = 7.0f;
|
1027
|
+
cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
|
1028
|
+
cb(cur, "ffn_moe_swiglu_oai", il);
|
1029
|
+
} break;
|
1030
|
+
case LLM_FFN_RELU:
|
1031
|
+
if (gate_exps) {
|
1032
|
+
cur = ggml_reglu_split(ctx0, cur, up);
|
1033
|
+
cb(cur, "ffn_moe_reglu", il);
|
1034
|
+
} else {
|
1035
|
+
cur = ggml_relu(ctx0, cur);
|
1036
|
+
cb(cur, "ffn_moe_relu", il);
|
1037
|
+
} break;
|
760
1038
|
default:
|
761
1039
|
GGML_ABORT("fatal error");
|
762
1040
|
}
|
@@ -764,25 +1042,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
764
1042
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
765
1043
|
cb(experts, "ffn_moe_down", il);
|
766
1044
|
|
1045
|
+
if (down_exps_b) {
|
1046
|
+
experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
|
1047
|
+
cb(experts, "ffn_moe_down_biased", il);
|
1048
|
+
}
|
1049
|
+
|
767
1050
|
if (!weight_before_ffn) {
|
768
1051
|
experts = ggml_mul(ctx0, experts, weights);
|
769
1052
|
cb(cur, "ffn_moe_weighted", il);
|
770
1053
|
}
|
771
1054
|
|
1055
|
+
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
|
1056
|
+
|
1057
|
+
assert(n_expert_used > 0);
|
1058
|
+
|
1059
|
+
// order the views before the adds
|
1060
|
+
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
|
1061
|
+
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
|
1062
|
+
|
1063
|
+
ggml_build_forward_expand(gf, cur_experts[i]);
|
1064
|
+
}
|
1065
|
+
|
772
1066
|
// aggregate experts
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
1067
|
+
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
|
1068
|
+
// to avoid potentially a large number of add nodes during warmup
|
1069
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
|
1070
|
+
ggml_tensor * moe_out = cur_experts[0];
|
777
1071
|
|
778
|
-
|
779
|
-
|
780
|
-
} else {
|
781
|
-
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
782
|
-
}
|
1072
|
+
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
|
1073
|
+
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
|
783
1074
|
}
|
784
1075
|
|
785
|
-
if (n_expert_used == 1) {
|
1076
|
+
if (hparams.n_expert_used == 1) {
|
786
1077
|
// avoid returning a non-contiguous tensor
|
787
1078
|
moe_out = ggml_cont(ctx0, moe_out);
|
788
1079
|
}
|
@@ -906,7 +1197,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
|
906
1197
|
}
|
907
1198
|
|
908
1199
|
ggml_tensor * llm_graph_context::build_inp_cls() const {
|
909
|
-
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
|
1200
|
+
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
|
910
1201
|
|
911
1202
|
auto & cur = inp->cls;
|
912
1203
|
|
@@ -956,7 +1247,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
956
1247
|
}
|
957
1248
|
|
958
1249
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
959
|
-
const auto * mctx_cur = static_cast<const
|
1250
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
960
1251
|
|
961
1252
|
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
962
1253
|
|
@@ -987,51 +1278,28 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
987
1278
|
return pos_bias;
|
988
1279
|
}
|
989
1280
|
|
990
|
-
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
991
|
-
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
992
|
-
|
993
|
-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
|
994
|
-
|
995
|
-
{
|
996
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
997
|
-
|
998
|
-
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
999
|
-
|
1000
|
-
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
1001
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
1002
|
-
ggml_set_input(inp->self_kq_mask);
|
1003
|
-
|
1004
|
-
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1005
|
-
}
|
1006
|
-
|
1007
|
-
{
|
1008
|
-
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
|
1009
|
-
|
1010
|
-
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
1011
|
-
ggml_set_input(inp->s_copy);
|
1012
|
-
}
|
1013
|
-
|
1014
|
-
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
1015
|
-
}
|
1016
|
-
|
1017
1281
|
ggml_tensor * llm_graph_context::build_attn_mha(
|
1018
|
-
ggml_cgraph * gf,
|
1019
1282
|
ggml_tensor * q,
|
1020
1283
|
ggml_tensor * k,
|
1021
1284
|
ggml_tensor * v,
|
1022
1285
|
ggml_tensor * kq_b,
|
1023
1286
|
ggml_tensor * kq_mask,
|
1287
|
+
ggml_tensor * sinks,
|
1024
1288
|
ggml_tensor * v_mla,
|
1025
|
-
|
1289
|
+
float kq_scale,
|
1290
|
+
int il) const {
|
1026
1291
|
const bool v_trans = v->nb[1] > v->nb[2];
|
1027
1292
|
|
1293
|
+
// split the batch into streams if needed
|
1294
|
+
const auto n_stream = k->ne[3];
|
1295
|
+
|
1296
|
+
q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
|
1297
|
+
|
1028
1298
|
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
1029
1299
|
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
1030
1300
|
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
1031
1301
|
|
1032
|
-
const auto
|
1033
|
-
const auto n_head = q->ne[2];
|
1034
|
-
const auto n_kv = k->ne[1];
|
1302
|
+
const auto n_kv = k->ne[1];
|
1035
1303
|
|
1036
1304
|
ggml_tensor * cur;
|
1037
1305
|
|
@@ -1054,8 +1322,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1054
1322
|
|
1055
1323
|
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
1056
1324
|
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
1325
|
+
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
|
1057
1326
|
|
1058
|
-
|
1327
|
+
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
1328
|
+
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
1059
1329
|
|
1060
1330
|
if (v_mla) {
|
1061
1331
|
#if 0
|
@@ -1068,14 +1338,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1068
1338
|
// The permutations are noops and only change how the tensor data is interpreted.
|
1069
1339
|
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
1070
1340
|
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
1341
|
+
cb(cur, "fattn_mla", il);
|
1071
1342
|
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
1072
1343
|
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
|
1073
1344
|
#endif
|
1074
1345
|
}
|
1075
1346
|
|
1076
|
-
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*
|
1347
|
+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
1077
1348
|
} else {
|
1078
1349
|
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
1350
|
+
cb(kq, "kq", il);
|
1079
1351
|
|
1080
1352
|
// note: this op tends to require high floating point range
|
1081
1353
|
// while for some models F16 is enough, for others it is not, so we default to F32 here
|
@@ -1083,42 +1355,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1083
1355
|
|
1084
1356
|
if (arch == LLM_ARCH_GROK) {
|
1085
1357
|
// need to do the following:
|
1086
|
-
// multiply by
|
1358
|
+
// multiply by attn_output_multiplier
|
1087
1359
|
// and then :
|
1088
1360
|
// kq = 30 * tanh(kq / 30)
|
1089
1361
|
// before the softmax below
|
1090
1362
|
|
1091
|
-
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq,
|
1092
|
-
kq
|
1363
|
+
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
|
1364
|
+
cb(kq, "kq_tanh", il);
|
1365
|
+
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
|
1366
|
+
cb(kq, "kq_scaled", il);
|
1093
1367
|
}
|
1094
1368
|
|
1095
1369
|
if (hparams.attn_soft_cap) {
|
1096
1370
|
kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
|
1371
|
+
cb(kq, "kq_scaled_1", il);
|
1097
1372
|
kq = ggml_tanh (ctx0, kq);
|
1373
|
+
cb(kq, "kq_tanh", il);
|
1098
1374
|
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
|
1375
|
+
cb(kq, "kq_scaled_2", il);
|
1099
1376
|
}
|
1100
1377
|
|
1101
1378
|
if (kq_b) {
|
1102
1379
|
kq = ggml_add(ctx0, kq, kq_b);
|
1380
|
+
cb(kq, "kq_plus_kq_b", il);
|
1103
1381
|
}
|
1104
1382
|
|
1105
1383
|
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
1384
|
+
ggml_soft_max_add_sinks(kq, sinks);
|
1385
|
+
cb(kq, "kq_soft_max", il);
|
1106
1386
|
|
1107
1387
|
if (!v_trans) {
|
1108
1388
|
// note: avoid this branch
|
1109
1389
|
v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
1390
|
+
cb(v, "v_cont", il);
|
1110
1391
|
}
|
1111
1392
|
|
1112
1393
|
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
1394
|
+
cb(kqv, "kqv", il);
|
1113
1395
|
|
1114
1396
|
// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
|
1115
1397
|
if (v_mla) {
|
1116
1398
|
kqv = ggml_mul_mat(ctx0, v_mla, kqv);
|
1399
|
+
cb(kqv, "kqv_mla", il);
|
1117
1400
|
}
|
1118
1401
|
|
1119
1402
|
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
1120
1403
|
|
1121
|
-
|
1404
|
+
// recombine streams
|
1405
|
+
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
1122
1406
|
|
1123
1407
|
if (!cparams.offload_kqv) {
|
1124
1408
|
// all nodes between the KV store and the attention output are run on the CPU
|
@@ -1135,8 +1419,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
1135
1419
|
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
1136
1420
|
|
1137
1421
|
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
1138
|
-
inp->kq_mask =
|
1139
|
-
//cb(inp_kq_mask, "KQ_mask", -1);
|
1422
|
+
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
1140
1423
|
ggml_set_input(inp->kq_mask);
|
1141
1424
|
|
1142
1425
|
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
@@ -1146,13 +1429,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
1146
1429
|
|
1147
1430
|
ggml_tensor * llm_graph_context::build_attn(
|
1148
1431
|
llm_graph_input_attn_no_cache * inp,
|
1149
|
-
ggml_cgraph * gf,
|
1150
1432
|
ggml_tensor * wo,
|
1151
1433
|
ggml_tensor * wo_b,
|
1152
1434
|
ggml_tensor * q_cur,
|
1153
1435
|
ggml_tensor * k_cur,
|
1154
1436
|
ggml_tensor * v_cur,
|
1155
1437
|
ggml_tensor * kq_b,
|
1438
|
+
ggml_tensor * sinks,
|
1156
1439
|
ggml_tensor * v_mla,
|
1157
1440
|
float kq_scale,
|
1158
1441
|
int il) const {
|
@@ -1166,11 +1449,16 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1166
1449
|
|
1167
1450
|
const auto & kq_mask = inp->get_kq_mask();
|
1168
1451
|
|
1452
|
+
// [TAG_NO_CACHE_PAD]
|
1453
|
+
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
1454
|
+
// but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
|
1455
|
+
//assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
|
1456
|
+
|
1169
1457
|
ggml_tensor * q = q_cur;
|
1170
1458
|
ggml_tensor * k = k_cur;
|
1171
1459
|
ggml_tensor * v = v_cur;
|
1172
1460
|
|
1173
|
-
ggml_tensor * cur = build_attn_mha(
|
1461
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1174
1462
|
cb(cur, "kqv_out", il);
|
1175
1463
|
|
1176
1464
|
if (wo) {
|
@@ -1188,35 +1476,51 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1188
1476
|
return cur;
|
1189
1477
|
}
|
1190
1478
|
|
1191
|
-
|
1192
|
-
|
1479
|
+
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
1480
|
+
ggml_context * ctx0,
|
1481
|
+
const llama_ubatch & ubatch,
|
1482
|
+
const llama_hparams & hparams,
|
1483
|
+
const llama_cparams & cparams,
|
1484
|
+
const llama_kv_cache_context * mctx_cur) {
|
1193
1485
|
|
1194
|
-
auto inp = std::make_unique<
|
1486
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
|
1195
1487
|
|
1196
1488
|
{
|
1197
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use
|
1489
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
1490
|
+
|
1491
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
1492
|
+
const auto n_tokens = ubatch.n_tokens;
|
1493
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
1198
1494
|
|
1199
|
-
|
1495
|
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
1496
|
+
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
1200
1497
|
|
1201
|
-
inp->self_kq_mask =
|
1202
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
1498
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1203
1499
|
ggml_set_input(inp->self_kq_mask);
|
1204
1500
|
|
1205
1501
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1206
1502
|
}
|
1207
1503
|
|
1208
|
-
return
|
1504
|
+
return inp;
|
1505
|
+
}
|
1506
|
+
|
1507
|
+
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
|
1508
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
1509
|
+
|
1510
|
+
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
1511
|
+
|
1512
|
+
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
|
1209
1513
|
}
|
1210
1514
|
|
1211
1515
|
ggml_tensor * llm_graph_context::build_attn(
|
1212
|
-
|
1213
|
-
ggml_cgraph * gf,
|
1516
|
+
llm_graph_input_attn_kv * inp,
|
1214
1517
|
ggml_tensor * wo,
|
1215
1518
|
ggml_tensor * wo_b,
|
1216
1519
|
ggml_tensor * q_cur,
|
1217
1520
|
ggml_tensor * k_cur,
|
1218
1521
|
ggml_tensor * v_cur,
|
1219
1522
|
ggml_tensor * kq_b,
|
1523
|
+
ggml_tensor * sinks,
|
1220
1524
|
ggml_tensor * v_mla,
|
1221
1525
|
float kq_scale,
|
1222
1526
|
int il) const {
|
@@ -1226,12 +1530,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1226
1530
|
ggml_build_forward_expand(gf, k_cur);
|
1227
1531
|
ggml_build_forward_expand(gf, v_cur);
|
1228
1532
|
|
1229
|
-
const auto * mctx_cur =
|
1533
|
+
const auto * mctx_cur = inp->mctx;
|
1230
1534
|
|
1231
1535
|
// store to KV cache
|
1232
1536
|
{
|
1233
|
-
|
1234
|
-
|
1537
|
+
const auto & k_idxs = inp->get_k_idxs();
|
1538
|
+
const auto & v_idxs = inp->get_v_idxs();
|
1539
|
+
|
1540
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
1541
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
1235
1542
|
}
|
1236
1543
|
|
1237
1544
|
const auto & kq_mask = inp->get_kq_mask();
|
@@ -1240,13 +1547,13 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1240
1547
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
1241
1548
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
1242
1549
|
|
1243
|
-
ggml_tensor * cur = build_attn_mha(
|
1550
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1244
1551
|
cb(cur, "kqv_out", il);
|
1245
1552
|
|
1246
1553
|
if (wo) {
|
1247
1554
|
cur = build_lora_mm(wo, cur);
|
1248
|
-
if (arch == LLM_ARCH_GLM4) {
|
1249
|
-
// GLM4
|
1555
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
1556
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
1250
1557
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
1251
1558
|
}
|
1252
1559
|
}
|
@@ -1259,14 +1566,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1259
1566
|
}
|
1260
1567
|
|
1261
1568
|
ggml_tensor * llm_graph_context::build_attn(
|
1262
|
-
|
1263
|
-
ggml_cgraph * gf,
|
1569
|
+
llm_graph_input_attn_kv_iswa * inp,
|
1264
1570
|
ggml_tensor * wo,
|
1265
1571
|
ggml_tensor * wo_b,
|
1266
1572
|
ggml_tensor * q_cur,
|
1267
1573
|
ggml_tensor * k_cur,
|
1268
1574
|
ggml_tensor * v_cur,
|
1269
1575
|
ggml_tensor * kq_b,
|
1576
|
+
ggml_tensor * sinks,
|
1270
1577
|
ggml_tensor * v_mla,
|
1271
1578
|
float kq_scale,
|
1272
1579
|
int il) const {
|
@@ -1282,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1282
1589
|
ggml_build_forward_expand(gf, v_cur);
|
1283
1590
|
}
|
1284
1591
|
|
1285
|
-
const auto * mctx_iswa =
|
1592
|
+
const auto * mctx_iswa = inp->mctx;
|
1286
1593
|
|
1287
1594
|
const bool is_swa = hparams.is_swa(il);
|
1288
1595
|
|
@@ -1290,11 +1597,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1290
1597
|
|
1291
1598
|
// optionally store to KV cache
|
1292
1599
|
if (k_cur) {
|
1293
|
-
|
1600
|
+
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
1601
|
+
|
1602
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
1294
1603
|
}
|
1295
1604
|
|
1296
1605
|
if (v_cur) {
|
1297
|
-
|
1606
|
+
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
1607
|
+
|
1608
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
1298
1609
|
}
|
1299
1610
|
|
1300
1611
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
@@ -1303,7 +1614,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1303
1614
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
1304
1615
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
1305
1616
|
|
1306
|
-
ggml_tensor * cur = build_attn_mha(
|
1617
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1307
1618
|
cb(cur, "kqv_out", il);
|
1308
1619
|
|
1309
1620
|
if (wo) {
|
@@ -1326,7 +1637,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
1326
1637
|
|
1327
1638
|
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
1328
1639
|
|
1329
|
-
inp->cross_kq_mask =
|
1640
|
+
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
1330
1641
|
ggml_set_input(inp->cross_kq_mask);
|
1331
1642
|
|
1332
1643
|
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
@@ -1336,13 +1647,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
1336
1647
|
|
1337
1648
|
ggml_tensor * llm_graph_context::build_attn(
|
1338
1649
|
llm_graph_input_attn_cross * inp,
|
1339
|
-
ggml_cgraph * gf,
|
1340
1650
|
ggml_tensor * wo,
|
1341
1651
|
ggml_tensor * wo_b,
|
1342
1652
|
ggml_tensor * q_cur,
|
1343
1653
|
ggml_tensor * k_cur,
|
1344
1654
|
ggml_tensor * v_cur,
|
1345
1655
|
ggml_tensor * kq_b,
|
1656
|
+
ggml_tensor * sinks,
|
1346
1657
|
ggml_tensor * v_mla,
|
1347
1658
|
float kq_scale,
|
1348
1659
|
int il) const {
|
@@ -1358,7 +1669,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1358
1669
|
ggml_tensor * k = k_cur;
|
1359
1670
|
ggml_tensor * v = v_cur;
|
1360
1671
|
|
1361
|
-
ggml_tensor * cur = build_attn_mha(
|
1672
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
1362
1673
|
cb(cur, "kqv_out", il);
|
1363
1674
|
|
1364
1675
|
if (wo) {
|
@@ -1376,171 +1687,124 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
1376
1687
|
return cur;
|
1377
1688
|
}
|
1378
1689
|
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1382
|
-
|
1383
|
-
|
1384
|
-
ggml_tensor * q_cur,
|
1385
|
-
ggml_tensor * k_cur,
|
1386
|
-
ggml_tensor * v_cur,
|
1387
|
-
ggml_tensor * kq_b,
|
1388
|
-
ggml_tensor * v_mla,
|
1389
|
-
float kq_scale,
|
1390
|
-
int il) const {
|
1391
|
-
// these nodes are added to the graph together so that they are not reordered
|
1392
|
-
// by doing so, the number of splits in the graph is reduced
|
1393
|
-
ggml_build_forward_expand(gf, q_cur);
|
1394
|
-
ggml_build_forward_expand(gf, k_cur);
|
1395
|
-
ggml_build_forward_expand(gf, v_cur);
|
1396
|
-
|
1397
|
-
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
|
1398
|
-
|
1399
|
-
// store to KV cache
|
1400
|
-
{
|
1401
|
-
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
1402
|
-
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
1403
|
-
}
|
1404
|
-
|
1405
|
-
const auto & kq_mask = inp->get_kq_mask();
|
1406
|
-
|
1407
|
-
ggml_tensor * q = q_cur;
|
1408
|
-
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
1409
|
-
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
1410
|
-
|
1411
|
-
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
1412
|
-
cb(cur, "kqv_out", il);
|
1413
|
-
|
1414
|
-
if (wo) {
|
1415
|
-
cur = build_lora_mm(wo, cur);
|
1416
|
-
if (arch == LLM_ARCH_GLM4) {
|
1417
|
-
// GLM4 seems to have numerical issues with half-precision accumulators
|
1418
|
-
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
1419
|
-
}
|
1420
|
-
}
|
1421
|
-
|
1422
|
-
if (wo_b) {
|
1423
|
-
cur = ggml_add(ctx0, cur, wo_b);
|
1424
|
-
}
|
1425
|
-
|
1426
|
-
return cur;
|
1427
|
-
}
|
1690
|
+
// TODO: maybe separate the inner implementation into a separate function
|
1691
|
+
// like with the non-sliding window equivalent
|
1692
|
+
// once sliding-window hybrid caches are a thing.
|
1693
|
+
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
|
1694
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
|
1428
1695
|
|
1429
|
-
|
1430
|
-
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
1696
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
1431
1697
|
|
1432
|
-
auto
|
1698
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
1433
1699
|
|
1434
1700
|
{
|
1435
1701
|
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
1436
1702
|
|
1437
|
-
inp->
|
1438
|
-
|
1703
|
+
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
1704
|
+
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
1705
|
+
|
1706
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1439
1707
|
ggml_set_input(inp->self_kq_mask);
|
1440
1708
|
|
1441
1709
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
1442
1710
|
}
|
1443
1711
|
|
1444
1712
|
{
|
1445
|
-
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use
|
1713
|
+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
1446
1714
|
|
1447
1715
|
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
1448
1716
|
|
1449
|
-
inp->
|
1450
|
-
|
1717
|
+
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
1718
|
+
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
1719
|
+
|
1720
|
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
1451
1721
|
ggml_set_input(inp->self_kq_mask_swa);
|
1452
1722
|
|
1453
1723
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
1454
1724
|
}
|
1455
1725
|
|
1456
|
-
return (
|
1726
|
+
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
1457
1727
|
}
|
1458
1728
|
|
1459
1729
|
ggml_tensor * llm_graph_context::build_rs(
|
1460
|
-
ggml_cgraph * gf,
|
1461
1730
|
ggml_tensor * s,
|
1462
|
-
ggml_tensor *
|
1731
|
+
ggml_tensor * state_copy_main,
|
1732
|
+
ggml_tensor * state_copy_extra,
|
1463
1733
|
int32_t state_size,
|
1464
1734
|
int32_t n_seqs,
|
1465
|
-
uint32_t
|
1466
|
-
uint32_t
|
1467
|
-
uint32_t
|
1735
|
+
uint32_t n_rs,
|
1736
|
+
uint32_t rs_head,
|
1737
|
+
uint32_t rs_size,
|
1468
1738
|
int32_t rs_zero,
|
1469
|
-
|
1739
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
1470
1740
|
|
1471
|
-
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size,
|
1741
|
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
|
1472
1742
|
|
1473
1743
|
// Clear a single state which will then be copied to the other cleared states.
|
1474
1744
|
// Note that this is a no-op when the view is zero-sized.
|
1475
1745
|
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
1476
1746
|
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
1477
1747
|
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
// {state_size, kv_size} -> {state_size, n_seqs}
|
1484
|
-
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
1485
|
-
ggml_build_forward_expand(gf, output_states);
|
1486
|
-
} else {
|
1487
|
-
// FIXME: make the gathering operation happen before the copy below
|
1488
|
-
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
1489
|
-
output_states = states;
|
1490
|
-
}
|
1748
|
+
// copy states
|
1749
|
+
// NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
|
1750
|
+
// {state_size, rs_size} -> {state_size, n_seqs}
|
1751
|
+
ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
|
1752
|
+
ggml_build_forward_expand(gf, output_states);
|
1491
1753
|
|
1492
|
-
// copy extra states which won't be changed further (between n_seqs and
|
1493
|
-
ggml_tensor * states_extra = ggml_get_rows(ctx0, states,
|
1754
|
+
// copy extra states which won't be changed further (between n_seqs and n_rs)
|
1755
|
+
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
|
1494
1756
|
ggml_build_forward_expand(gf,
|
1495
1757
|
ggml_cpy(ctx0,
|
1496
1758
|
states_extra,
|
1497
|
-
ggml_view_1d(ctx0, s, state_size*(
|
1759
|
+
ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
|
1498
1760
|
|
1499
1761
|
return output_states;
|
1500
1762
|
}
|
1501
1763
|
|
1502
|
-
llm_graph_input_rs
|
1503
|
-
|
1764
|
+
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
1765
|
+
ggml_context * ctx0,
|
1766
|
+
const llama_ubatch & ubatch,
|
1767
|
+
const llama_memory_recurrent_context * mctx_cur) {
|
1504
1768
|
|
1505
1769
|
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
1506
1770
|
|
1507
|
-
const
|
1771
|
+
const int64_t n_rs = mctx_cur->get_n_rs();
|
1772
|
+
const int64_t n_seqs = ubatch.n_seqs;
|
1508
1773
|
|
1509
1774
|
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
1510
1775
|
ggml_set_input(inp->s_copy);
|
1511
1776
|
|
1512
|
-
|
1777
|
+
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
1778
|
+
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
1779
|
+
|
1780
|
+
return inp;
|
1513
1781
|
}
|
1514
1782
|
|
1515
|
-
|
1516
|
-
llm_graph_input_rs * inp,
|
1517
|
-
ggml_cgraph * gf,
|
1518
|
-
ggml_tensor * s,
|
1519
|
-
int32_t state_size,
|
1520
|
-
int32_t n_seqs,
|
1521
|
-
bool avoid_copies) const {
|
1783
|
+
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
1522
1784
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
1523
1785
|
|
1524
|
-
|
1786
|
+
auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
|
1787
|
+
|
1788
|
+
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
1525
1789
|
}
|
1526
1790
|
|
1527
1791
|
ggml_tensor * llm_graph_context::build_rs(
|
1528
|
-
|
1529
|
-
ggml_cgraph * gf,
|
1792
|
+
llm_graph_input_rs * inp,
|
1530
1793
|
ggml_tensor * s,
|
1531
1794
|
int32_t state_size,
|
1532
1795
|
int32_t n_seqs,
|
1533
|
-
|
1534
|
-
const auto *
|
1796
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
1797
|
+
const auto * kv_state = inp->mctx;
|
1535
1798
|
|
1536
|
-
return build_rs(
|
1799
|
+
return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
|
1800
|
+
kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
|
1801
|
+
get_state_rows);
|
1537
1802
|
}
|
1538
1803
|
|
1539
1804
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
1540
1805
|
llm_graph_input_rs * inp,
|
1541
|
-
ggml_cgraph * gf,
|
1542
1806
|
const llama_ubatch & ubatch,
|
1543
|
-
|
1807
|
+
int il) const {
|
1544
1808
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
1545
1809
|
|
1546
1810
|
const auto token_shift_count = hparams.token_shift_count;
|
@@ -1550,7 +1814,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
1550
1814
|
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
1551
1815
|
|
1552
1816
|
ggml_tensor * token_shift = build_rs(
|
1553
|
-
inp,
|
1817
|
+
inp, token_shift_all,
|
1554
1818
|
hparams.n_embd_r(), n_seqs);
|
1555
1819
|
|
1556
1820
|
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
@@ -1578,8 +1842,18 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
1578
1842
|
);
|
1579
1843
|
}
|
1580
1844
|
|
1845
|
+
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
1846
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
1847
|
+
|
1848
|
+
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
1849
|
+
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
1850
|
+
|
1851
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
1852
|
+
|
1853
|
+
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
1854
|
+
}
|
1855
|
+
|
1581
1856
|
void llm_graph_context::build_pooling(
|
1582
|
-
ggml_cgraph * gf,
|
1583
1857
|
ggml_tensor * cls,
|
1584
1858
|
ggml_tensor * cls_b,
|
1585
1859
|
ggml_tensor * cls_out,
|
@@ -1623,34 +1897,32 @@ void llm_graph_context::build_pooling(
|
|
1623
1897
|
case LLAMA_POOLING_TYPE_RANK:
|
1624
1898
|
{
|
1625
1899
|
ggml_tensor * inp_cls = build_inp_cls();
|
1626
|
-
|
1900
|
+
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
1627
1901
|
|
1902
|
+
// classification head
|
1903
|
+
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
1628
1904
|
if (cls) {
|
1629
|
-
|
1630
|
-
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
1631
|
-
cur = ggml_mul_mat(ctx0, cls, inp);
|
1905
|
+
cur = ggml_mul_mat(ctx0, cls, cur);
|
1632
1906
|
if (cls_b) {
|
1633
1907
|
cur = ggml_add(ctx0, cur, cls_b);
|
1634
1908
|
}
|
1635
1909
|
cur = ggml_tanh(ctx0, cur);
|
1910
|
+
}
|
1636
1911
|
|
1637
|
-
|
1638
|
-
|
1639
|
-
|
1640
|
-
|
1641
|
-
|
1642
|
-
|
1643
|
-
}
|
1644
|
-
}
|
1645
|
-
} else if (cls_out) {
|
1646
|
-
// Single layer classification head (direct projection)
|
1647
|
-
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
1648
|
-
cur = ggml_mul_mat(ctx0, cls_out, inp);
|
1912
|
+
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
1913
|
+
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
1914
|
+
// Single layer classification head (direct projection)
|
1915
|
+
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
1916
|
+
if (cls_out) {
|
1917
|
+
cur = ggml_mul_mat(ctx0, cls_out, cur);
|
1649
1918
|
if (cls_out_b) {
|
1650
1919
|
cur = ggml_add(ctx0, cur, cls_out_b);
|
1651
1920
|
}
|
1652
|
-
}
|
1653
|
-
|
1921
|
+
}
|
1922
|
+
|
1923
|
+
// softmax for qwen3 reranker
|
1924
|
+
if (arch == LLM_ARCH_QWEN3) {
|
1925
|
+
cur = ggml_soft_max(ctx0, cur);
|
1654
1926
|
}
|
1655
1927
|
} break;
|
1656
1928
|
default:
|